From b07dc44f042b2c4c0c80be3420f6d896044742d9 Mon Sep 17 00:00:00 2001 From: wuweikang Date: Sat, 29 Aug 2020 15:35:03 +0800 Subject: [PATCH] code sync for runpackage C75B150 --- inc/common/optimizer/graph_optimizer.h | 3 + .../ai_core/common/aicore_util_attr_define.h | 39 + .../util/ai_core/common/aicore_util_types.h | 118 ++ inc/common/util/ai_core/common/graph_comm.h | 107 ++ .../util/ai_core/common/scope_allocator.h | 42 + .../param_calculate/aicore_param_calculator.h | 23 +- .../param_calculate/tensorsize_calculator.h | 45 + inc/external/ge/ge_api.h | 9 + inc/external/ge/ge_ir_build.h | 18 + inc/external/graph/operator.h | 10 +- inc/external/graph/operator_reg.h | 8 + .../scope/scope_fusion_pass_register.h | 331 +++++ inc/framework/common/ge_types.h | 33 +- inc/framework/common/types.h | 1 + inc/framework/common/util.h | 8 + inc/framework/engine/dnnengine.h | 1 + inc/framework/executor/ge_executor.h | 12 + inc/framework/memory/memory_api.h | 1 + inc/framework/omg/omg.h | 4 + inc/graph/debug/ge_attr_define.h | 5 +- inc/graph/ge_context.h | 1 + inc/graph/op_desc.h | 3 - inc/graph/utils/node_utils.h | 30 +- inc/graph/utils/type_utils.h | 1 + src/common/graph/ge_attr_define.cc | 6 +- src/common/graph/op_desc.cc | 12 - src/common/graph/operator.cc | 34 + src/common/graph/option/ge_context.cc | 12 + src/common/graph/ref_relation.cc | 4 +- src/common/graph/shape_refiner.cc | 92 +- src/common/graph/utils/node_utils.cc | 221 ++- src/common/graph/utils/op_desc_utils.cc | 5 + src/common/graph/utils/type_utils.cc | 15 + src/ge/CMakeLists.txt | 18 + src/ge/client/ge_api.cc | 16 + src/ge/common/dump/dump_manager.cc | 120 ++ src/ge/common/dump/dump_manager.h | 42 + src/ge/common/dump/dump_op.cc | 255 ++++ src/ge/common/dump/dump_op.h | 61 + src/ge/common/dump/dump_properties.cc | 238 ++++ src/ge/common/dump/dump_properties.h | 86 ++ src/ge/common/ge/datatype_util.cc | 58 +- src/ge/common/ge/datatype_util.h | 9 +- src/ge/common/ge/tbe_plugin_manager.cc | 4 +- 
src/ge/common/ge_common.mk | 1 - src/ge/common/math/math_util.h | 20 +- .../common/model_parser/graph_parser_util.cc | 501 ------- .../common/model_parser/graph_parser_util.h | 62 - src/ge/common/profiling/profiling_manager.cc | 6 +- src/ge/common/properties_manager.cc | 187 --- src/ge/common/properties_manager.h | 45 +- src/ge/common/types.cc | 1 + src/ge/common/util.cc | 20 +- src/ge/engine_manager/dnnengine_manager.cc | 19 + src/ge/engine_manager/dnnengine_manager.h | 1 + src/ge/engine_manager/engine_conf.json | 7 + src/ge/executor/CMakeLists.txt | 3 + src/ge/executor/ge_executor.cc | 93 +- src/ge/executor/module.mk | 3 + src/ge/ge_inference.mk | 12 +- src/ge/ge_runner.mk | 25 +- src/ge/generator/ge_generator.cc | 3 + src/ge/graph/build/graph_builder.cc | 17 +- src/ge/graph/build/graph_builder.h | 4 +- .../graph/build/memory/block_mem_assigner.cc | 36 +- .../graph/build/memory/block_mem_assigner.h | 2 + .../graph/build/memory/graph_mem_assigner.cc | 41 +- .../graph/build/memory/graph_mem_assigner.h | 2 + src/ge/graph/build/model_builder.cc | 28 +- src/ge/graph/build/stream_allocator.cc | 56 +- src/ge/graph/build/stream_allocator.h | 3 +- src/ge/graph/build/task_generator.cc | 13 +- src/ge/graph/common/transop_util.cc | 20 + src/ge/graph/common/transop_util.h | 2 + src/ge/graph/execute/graph_execute.cc | 33 +- src/ge/graph/execute/graph_execute.h | 11 + .../load/new_model_manager/aipp_utils.cc | 4 + .../load/new_model_manager/data_dumper.cc | 97 +- .../load/new_model_manager/data_dumper.h | 9 + .../load/new_model_manager/davinci_model.cc | 272 ++-- .../load/new_model_manager/davinci_model.h | 60 +- .../load/new_model_manager/model_manager.cc | 42 +- .../load/new_model_manager/model_manager.h | 10 + .../task_info/hccl_task_info.cc | 57 +- .../task_info/hccl_task_info.h | 5 +- .../task_info/kernel_task_info.cc | 45 +- .../task_info/kernel_task_info.h | 1 + .../task_info/memcpy_addr_async_task_info.cc | 7 +- .../task_info/memcpy_async_task_info.cc | 48 +- 
.../task_info/memcpy_async_task_info.h | 13 +- .../task_info/stream_switchn_task_info.cc | 48 +- .../task_info/stream_switchn_task_info.h | 1 + .../load/new_model_manager/zero_copy_task.cc | 5 +- .../load/new_model_manager/zero_copy_task.h | 5 +- src/ge/graph/manager/graph_manager.cc | 433 +++--- src/ge/graph/manager/graph_manager.h | 13 +- src/ge/graph/manager/graph_var_manager.cc | 44 + src/ge/graph/manager/graph_var_manager.h | 6 + src/ge/graph/manager/rdma_pool_allocator.cc | 1 - src/ge/graph/manager/rdma_pool_allocator.h | 3 +- src/ge/graph/manager/util/rt_context_util.cc | 4 + src/ge/graph/optimize/graph_optimize.cc | 49 +- src/ge/graph/optimize/graph_optimize.h | 8 +- .../optimize/mem_rw_conflict_optimize.cc | 306 ++-- .../partition/dynamic_shape_partition.cc | 94 +- .../graph/partition/dynamic_shape_partition.h | 18 +- src/ge/graph/passes/assign_pass.cc | 133 ++ .../{switch_fusion_pass.h => assign_pass.h} | 26 +- .../graph/passes/attach_stream_label_pass.cc | 71 +- .../graph/passes/attach_stream_label_pass.h | 4 +- src/ge/graph/passes/cast_remove_pass.cc | 15 +- src/ge/graph/passes/cast_remove_pass.h | 1 + src/ge/graph/passes/cond_pass.cc | 12 +- src/ge/graph/passes/flow_ctrl_pass.cc | 10 +- src/ge/graph/passes/folding_pass.cc | 2 +- src/ge/graph/passes/folding_pass.h | 4 +- .../graph/passes/global_step_insert_pass.cc | 99 ++ src/ge/graph/passes/global_step_insert_pass.h | 57 + src/ge/graph/passes/hccl_memcpy_pass.cc | 24 +- src/ge/graph/passes/hccl_memcpy_pass.h | 2 +- src/ge/graph/passes/identity_pass.cc | 11 +- src/ge/graph/passes/infershape_pass.cc | 2 - src/ge/graph/passes/memcpy_addr_async_pass.cc | 36 +- src/ge/graph/passes/memcpy_addr_async_pass.h | 1 + .../passes/merge_to_stream_merge_pass.cc | 10 + src/ge/graph/passes/multi_batch_clone_pass.cc | 610 ++++++++ src/ge/graph/passes/multi_batch_clone_pass.h | 155 +++ src/ge/graph/passes/multi_batch_pass.cc | 150 +- src/ge/graph/passes/multi_batch_pass.h | 12 +- src/ge/graph/passes/net_output_pass.cc | 
58 +- src/ge/graph/passes/next_iteration_pass.cc | 254 ++-- src/ge/graph/passes/next_iteration_pass.h | 17 +- .../passes/replace_with_empty_const_pass.cc | 9 +- .../passes/subexpression_migration_pass.cc | 559 ++++++++ .../passes/subexpression_migration_pass.h | 137 ++ src/ge/graph/passes/subgraph_pass.cc | 24 +- src/ge/graph/passes/switch_fusion_pass.cc | 249 ---- src/ge/graph/passes/switch_split_pass.cc | 145 -- .../passes/switch_to_stream_switch_pass.cc | 2 +- .../transop_symmetry_elimination_pass.cc | 44 +- .../transop_symmetry_elimination_pass.h | 11 +- .../transop_without_reshape_fusion_pass.cc | 3 +- src/ge/graph/passes/unused_args_clean_pass.cc | 206 +++ src/ge/graph/passes/unused_args_clean_pass.h | 83 ++ .../graph/passes/variable_prepare_op_pass.cc | 99 +- .../graph/passes/variable_prepare_op_pass.h | 17 +- .../passes/variable_ref_delete_op_pass.cc | 4 + src/ge/graph/preprocess/graph_preprocess.cc | 716 +--------- src/ge/graph/preprocess/graph_preprocess.h | 21 - .../graph/preprocess/insert_op/ge_aipp_op.cc | 213 ++- .../graph/preprocess/insert_op/ge_aipp_op.h | 8 +- .../insert_op/util_insert_aipp_op.cc | 32 +- .../preprocess/multi_batch_copy_graph.cc | 473 ++++--- .../graph/preprocess/multi_batch_copy_graph.h | 26 +- .../graph/preprocess/multi_batch_options.cc | 258 ++++ src/ge/graph/preprocess/multi_batch_options.h | 72 + .../ops_kernel_store/op/assign_op.cc | 51 - .../ops_kernel_store/op/assign_op.h | 41 - .../ops_kernel_store/op/random_uniform_op.cc | 104 -- .../ops_kernel_store/op/random_uniform_op.h | 45 - .../ops_kernel_store/op/variable_op.cc | 46 - .../ops_kernel_store/op/variable_op.h | 41 - .../common/constant/constant.h | 14 +- .../engine/host_cpu_engine.cc} | 38 +- .../engine/host_cpu_engine.h} | 40 +- .../module.mk | 8 +- .../host_cpu_ops_kernel_info.cc} | 27 +- .../host_cpu_ops_kernel_info.h} | 28 +- .../ops_kernel_store/op/host_op.cc | 10 +- .../ops_kernel_store/op/host_op.h | 12 +- .../ops_kernel_store/op/op.h | 10 +- 
.../ops_kernel_store/op/op_factory.cc | 6 +- .../ops_kernel_store/op/op_factory.h | 12 +- src/ge/host_cpu_engine/proto/task.proto | 1 + src/ge/host_kernels/concat_offset_kernel.cc | 10 +- src/ge/host_kernels/floordiv_kernel.cc | 6 +- src/ge/host_kernels/reduce_prod_kernel.cc | 2 +- src/ge/host_kernels/rsqrt_kernel.cc | 143 +- src/ge/host_kernels/rsqrt_kernel.h | 5 + src/ge/hybrid/common/npu_memory_allocator.h | 5 + .../executor/hybrid_execution_context.cc | 12 +- .../executor/hybrid_execution_context.h | 11 +- .../hybrid/executor/hybrid_model_executor.cc | 16 + src/ge/hybrid/executor/hybrid_profiler.cc | 7 +- src/ge/hybrid/executor/node_done_manager.cc | 2 +- src/ge/hybrid/executor/node_state.cc | 28 +- src/ge/hybrid/executor/subgraph_context.cc | 7 +- src/ge/hybrid/executor/subgraph_executor.cc | 13 +- .../executor/worker/execution_engine.cc | 84 +- .../executor/worker/shape_inference_engine.cc | 77 +- .../executor/worker/shape_inference_engine.h | 3 + .../executor/worker/task_compile_engine.cc | 5 +- src/ge/hybrid/model/hybrid_model.cc | 2 + src/ge/hybrid/model/hybrid_model.h | 2 + src/ge/hybrid/model/hybrid_model_builder.cc | 81 +- src/ge/hybrid/model/hybrid_model_builder.h | 1 + src/ge/hybrid/model/node_item.cc | 75 +- src/ge/hybrid/model/node_item.h | 10 +- .../aicore/aicore_node_executor.cc | 8 + .../aicpu/aicpu_node_executor.cc | 9 +- .../compiledsubgraph/known_node_executor.cc | 30 +- .../controlop/control_op_executor.cc | 104 +- .../controlop/control_op_executor.h | 4 +- .../ge_local_node_executor.cc | 24 +- .../ge_local_node_executor.h | 0 .../node_executor/hccl/hccl_node_executor.cc | 11 +- .../host_cpu_node_executor.cc} | 44 +- .../host_cpu_node_executor.h} | 31 +- .../kernel/assign_kernel.cc | 14 +- .../kernel/assign_kernel.h | 12 +- .../{hostaicpu => host_cpu}/kernel/kernel.h | 10 +- .../kernel/no_op_kernel.cc | 10 +- .../kernel/no_op_kernel.h | 12 +- .../kernel/random_uniform_kernel.cc | 82 +- .../kernel/random_uniform_kernel.h | 12 +- 
.../kernel/variable_kernel.cc | 12 +- .../kernel/variable_kernel.h | 12 +- .../{hostaicpu => host_cpu}/kernel_factory.cc | 6 +- .../{hostaicpu => host_cpu}/kernel_factory.h | 14 +- src/ge/hybrid/node_executor/node_executor.cc | 8 + src/ge/hybrid/node_executor/node_executor.h | 2 + .../partitioned_call_node_executor.cc | 2 + .../node_executor/rts/rts_node_executor.cc | 91 ++ .../node_executor/rts/rts_node_executor.h | 45 + src/ge/hybrid/node_executor/task_context.cc | 47 +- src/ge/hybrid/node_executor/task_context.h | 6 +- src/ge/init/gelib.cc | 27 +- src/ge/ir_build/atc_ir_common.cc | 119 +- src/ge/ir_build/atc_ir_common.h | 3 + src/ge/ir_build/ge_ir_build.cc | 24 +- src/ge/offline/main.cc | 1227 ----------------- src/ge/offline/module.mk | 53 - src/ge/offline/single_op_parser.cc | 426 ------ src/ge/offline/single_op_parser.h | 79 -- .../opskernel_manager/ops_kernel_manager.cc | 10 +- src/ge/opskernel_manager/ops_kernel_manager.h | 2 +- .../optimizer_priority.pbtxt | 2 +- src/ge/plugin/engine/dnnengines.cc | 18 +- src/ge/plugin/engine/dnnengines.h | 16 +- src/ge/plugin/engine/engine_manage.cc | 39 +- src/ge/session/inner_session.cc | 13 + src/ge/session/inner_session.h | 7 + src/ge/session/omg.cc | 303 +++- src/ge/session/session_manager.cc | 62 + src/ge/session/session_manager.h | 10 + src/ge/single_op/single_op.cc | 21 +- src/ge/single_op/single_op.h | 7 +- src/ge/single_op/single_op_manager.cc | 48 +- src/ge/single_op/single_op_model.cc | 20 +- src/ge/single_op/single_op_model.h | 1 + src/ge/single_op/stream_resource.cc | 94 +- src/ge/single_op/stream_resource.h | 15 +- .../task/aicpu_kernel_task_builder.cc | 13 +- .../task/aicpu_kernel_task_builder.h | 4 +- src/ge/single_op/task/aicpu_task_builder.cc | 1 + src/ge/single_op/task/op_task.cc | 82 +- src/ge/single_op/task/op_task.h | 26 +- src/ge/single_op/task/tbe_task_builder.cc | 28 +- src/ge/single_op/task/tbe_task_builder.h | 8 +- src/proto/op_mapping_info.proto | 11 + .../fwkacllib/inc/cce/fwk_adpt_struct.h | 
1 + third_party/fwkacllib/inc/ops/aipp.h | 11 +- .../fwkacllib/inc/ops/control_flow_ops.h | 14 +- .../inc/ops/elewise_calculation_ops.h | 51 +- .../fwkacllib/inc/ops/functional_ops.h | 222 +++ third_party/fwkacllib/inc/ops/image_ops.h | 55 + third_party/fwkacllib/inc/ops/internal_ops.h | 16 + third_party/fwkacllib/inc/ops/math_ops.h | 73 +- .../fwkacllib/inc/ops/nn_calculation_ops.h | 267 +++- third_party/fwkacllib/inc/ops/nn_detect_ops.h | 69 +- third_party/fwkacllib/inc/ops/nn_norm_ops.h | 34 +- .../fwkacllib/inc/ops/nn_pooling_ops.h | 33 + .../fwkacllib/inc/ops/nn_training_ops.h | 42 +- .../fwkacllib/inc/ops/nonlinear_fuc_ops.h | 45 + .../fwkacllib/inc/ops/npu_loss_scale_ops.h | 29 +- third_party/fwkacllib/inc/ops/pad_ops.h | 24 +- third_party/fwkacllib/inc/ops/quantize_ops.h | 72 + third_party/fwkacllib/inc/ops/reduce_ops.h | 39 +- third_party/fwkacllib/inc/ops/rnn.h | 11 + third_party/fwkacllib/inc/ops/selection_ops.h | 54 +- .../fwkacllib/inc/ops/transformation_ops.h | 17 + third_party/fwkacllib/inc/runtime/base.h | 357 ++++- third_party/fwkacllib/inc/runtime/context.h | 2 +- third_party/fwkacllib/inc/runtime/dev.h | 26 +- .../fwkacllib/inc/runtime/dvfsprofile.h | 6 - third_party/fwkacllib/inc/runtime/event.h | 20 +- third_party/fwkacllib/inc/runtime/kernel.h | 90 +- third_party/fwkacllib/inc/runtime/mem.h | 48 +- third_party/fwkacllib/inc/runtime/rt_model.h | 30 +- third_party/fwkacllib/inc/runtime/stream.h | 28 +- third_party/fwkacllib/inc/tdt/tsd_client.h | 78 -- third_party/fwkacllib/inc/toolchain/slog.h | 57 +- 292 files changed, 10490 insertions(+), 6386 deletions(-) create mode 100644 inc/common/util/ai_core/common/aicore_util_attr_define.h create mode 100644 inc/common/util/ai_core/common/aicore_util_types.h create mode 100644 inc/common/util/ai_core/common/graph_comm.h create mode 100644 inc/common/util/ai_core/common/scope_allocator.h rename src/ge/graph/passes/switch_split_pass.h => inc/common/util/ai_core/param_calculate/aicore_param_calculator.h 
(63%) create mode 100644 inc/common/util/ai_core/param_calculate/tensorsize_calculator.h create mode 100644 inc/external/register/scope/scope_fusion_pass_register.h create mode 100644 src/ge/common/dump/dump_manager.cc create mode 100644 src/ge/common/dump/dump_manager.h create mode 100644 src/ge/common/dump/dump_op.cc create mode 100644 src/ge/common/dump/dump_op.h create mode 100644 src/ge/common/dump/dump_properties.cc create mode 100644 src/ge/common/dump/dump_properties.h delete mode 100644 src/ge/common/model_parser/graph_parser_util.cc delete mode 100644 src/ge/common/model_parser/graph_parser_util.h create mode 100644 src/ge/graph/passes/assign_pass.cc rename src/ge/graph/passes/{switch_fusion_pass.h => assign_pass.h} (53%) create mode 100644 src/ge/graph/passes/global_step_insert_pass.cc create mode 100644 src/ge/graph/passes/global_step_insert_pass.h create mode 100644 src/ge/graph/passes/multi_batch_clone_pass.cc create mode 100644 src/ge/graph/passes/multi_batch_clone_pass.h create mode 100644 src/ge/graph/passes/subexpression_migration_pass.cc create mode 100644 src/ge/graph/passes/subexpression_migration_pass.h delete mode 100644 src/ge/graph/passes/switch_fusion_pass.cc delete mode 100644 src/ge/graph/passes/switch_split_pass.cc create mode 100644 src/ge/graph/passes/unused_args_clean_pass.cc create mode 100644 src/ge/graph/passes/unused_args_clean_pass.h create mode 100644 src/ge/graph/preprocess/multi_batch_options.cc create mode 100644 src/ge/graph/preprocess/multi_batch_options.h delete mode 100644 src/ge/host_aicpu_engine/ops_kernel_store/op/assign_op.cc delete mode 100644 src/ge/host_aicpu_engine/ops_kernel_store/op/assign_op.h delete mode 100644 src/ge/host_aicpu_engine/ops_kernel_store/op/random_uniform_op.cc delete mode 100644 src/ge/host_aicpu_engine/ops_kernel_store/op/random_uniform_op.h delete mode 100644 src/ge/host_aicpu_engine/ops_kernel_store/op/variable_op.cc delete mode 100644 
src/ge/host_aicpu_engine/ops_kernel_store/op/variable_op.h rename src/ge/{host_aicpu_engine => host_cpu_engine}/common/constant/constant.h (66%) rename src/ge/{host_aicpu_engine/engine/host_aicpu_engine.cc => host_cpu_engine/engine/host_cpu_engine.cc} (52%) rename src/ge/{host_aicpu_engine/engine/host_aicpu_engine.h => host_cpu_engine/engine/host_cpu_engine.h} (71%) rename src/ge/{host_aicpu_engine => host_cpu_engine}/module.mk (88%) rename src/ge/{host_aicpu_engine/ops_kernel_store/host_aicpu_ops_kernel_info.cc => host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc} (82%) rename src/ge/{host_aicpu_engine/ops_kernel_store/host_aicpu_ops_kernel_info.h => host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.h} (70%) rename src/ge/{host_aicpu_engine => host_cpu_engine}/ops_kernel_store/op/host_op.cc (80%) rename src/ge/{host_aicpu_engine => host_cpu_engine}/ops_kernel_store/op/host_op.h (76%) rename src/ge/{host_aicpu_engine => host_cpu_engine}/ops_kernel_store/op/op.h (83%) rename src/ge/{host_aicpu_engine => host_cpu_engine}/ops_kernel_store/op/op_factory.cc (93%) rename src/ge/{host_aicpu_engine => host_cpu_engine}/ops_kernel_store/op/op_factory.h (90%) create mode 120000 src/ge/host_cpu_engine/proto/task.proto rename src/ge/hybrid/node_executor/{hostcpu => ge_local}/ge_local_node_executor.cc (88%) rename src/ge/hybrid/node_executor/{hostcpu => ge_local}/ge_local_node_executor.h (100%) rename src/ge/hybrid/node_executor/{hostaicpu/host_aicpu_node_executor.cc => host_cpu/host_cpu_node_executor.cc} (77%) rename src/ge/hybrid/node_executor/{hostaicpu/host_aicpu_node_executor.h => host_cpu/host_cpu_node_executor.h} (68%) rename src/ge/hybrid/node_executor/{hostaicpu => host_cpu}/kernel/assign_kernel.cc (82%) rename src/ge/hybrid/node_executor/{hostaicpu => host_cpu}/kernel/assign_kernel.h (79%) rename src/ge/hybrid/node_executor/{hostaicpu => host_cpu}/kernel/kernel.h (84%) rename src/ge/hybrid/node_executor/{hostaicpu => host_cpu}/kernel/no_op_kernel.cc 
(76%) rename src/ge/hybrid/node_executor/{hostaicpu => host_cpu}/kernel/no_op_kernel.h (79%) rename src/ge/hybrid/node_executor/{hostaicpu => host_cpu}/kernel/random_uniform_kernel.cc (53%) rename src/ge/hybrid/node_executor/{hostaicpu => host_cpu}/kernel/random_uniform_kernel.h (82%) rename src/ge/hybrid/node_executor/{hostaicpu => host_cpu}/kernel/variable_kernel.cc (76%) rename src/ge/hybrid/node_executor/{hostaicpu => host_cpu}/kernel/variable_kernel.h (79%) rename src/ge/hybrid/node_executor/{hostaicpu => host_cpu}/kernel_factory.cc (93%) rename src/ge/hybrid/node_executor/{hostaicpu => host_cpu}/kernel_factory.h (87%) create mode 100644 src/ge/hybrid/node_executor/rts/rts_node_executor.cc create mode 100644 src/ge/hybrid/node_executor/rts/rts_node_executor.h delete mode 100644 src/ge/offline/main.cc delete mode 100644 src/ge/offline/module.mk delete mode 100644 src/ge/offline/single_op_parser.cc delete mode 100644 src/ge/offline/single_op_parser.h mode change 100644 => 100755 src/ge/opskernel_manager/optimizer_priority.pbtxt diff --git a/inc/common/optimizer/graph_optimizer.h b/inc/common/optimizer/graph_optimizer.h index c330dd63..2c0cebe6 100644 --- a/inc/common/optimizer/graph_optimizer.h +++ b/inc/common/optimizer/graph_optimizer.h @@ -42,6 +42,9 @@ class GraphOptimizer { // optimize original graph for FE quant optimize virtual Status OptimizeGraphPrepare(ComputeGraph &graph) { return SUCCESS; } + // optimize graph before build for RTS + virtual Status OptimizeGraphBeforeBuild(ComputeGraph &graph) { return SUCCESS; } + // optimize original graph, using in graph preparation stage virtual Status OptimizeOriginalGraph(ComputeGraph &graph) = 0; diff --git a/inc/common/util/ai_core/common/aicore_util_attr_define.h b/inc/common/util/ai_core/common/aicore_util_attr_define.h new file mode 100644 index 00000000..6c20c470 --- /dev/null +++ b/inc/common/util/ai_core/common/aicore_util_attr_define.h @@ -0,0 +1,39 @@ +/** + * Copyright 2019-2020 Huawei Technologies 
Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_COMMON_UTILS_AI_CORE_COMMON_ATTR_DEFINE_H_ +#define INC_COMMON_UTILS_AI_CORE_COMMON_ATTR_DEFINE_H_ + +#include + +namespace fe { +static const std::string SCOPE_ID_ATTR = "fusion_scope"; + +static const std::string FE_IMPLY_TYPE = "_fe_imply_type"; + +static const std::string PARENT_OP_TYPE = "parentOpType"; + +static const std::string ATTR_NAME_TASK_L2_FUSION_INFO_EXTEND_PTR = "task_l2_fusion_info_extend_content"; + +static const std::string ATTR_DATA_DUMP_REF = "_datadump_ref"; + +static const std::string ATTR_NAME_L2_FUSION_EXTEND_PTR = "l2_fusion_extend_content"; + +static const std::string L1_OPTIMIZED = "l1_optimized"; + +static const std::string L2_OPTIMIZED = "l2_optimized"; +} // namespace fe +#endif diff --git a/inc/common/util/ai_core/common/aicore_util_types.h b/inc/common/util/ai_core/common/aicore_util_types.h new file mode 100644 index 00000000..b2615dc9 --- /dev/null +++ b/inc/common/util/ai_core/common/aicore_util_types.h @@ -0,0 +1,118 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_COMMON_UTILS_AI_CORE_COMMON_TYPES_H_ +#define INC_COMMON_UTILS_AI_CORE_COMMON_TYPES_H_ + +#include "graph/anchor.h" +#include "graph/types.h" +#include "runtime/kernel.h" +#include +#include +#include + +namespace fe { +struct FusionOpSrc { + uint32_t src_op_id; + ge::AnchorPtr src_anchor; + int32_t fusion_src_index; + int32_t fusion_dst_index; +}; + +struct FusionOpDst { + uint32_t dst_op_id; + ge::AnchorPtr dst_anchor; +}; + +struct FusionDataFlow { + std::pair edge; + std::pair node_dataindex_pair; +}; + +typedef struct tagL2FusionData { + uint32_t l2Index; + uint64_t l2Addr; + uint64_t l2PageNum; +} L2FusionData_t; +typedef std::map L2FusionDataMap_t; + +typedef struct tagFeSmDesc { + rtL2Ctrl_t l2ctrl; + std::string nodeName[8]; + uint8_t outputIndex[8]; +} feSmDesc_t; + +typedef struct TagTaskL2FusionInfo { + std::string nodeName; + feSmDesc_t l2Info; + L2FusionDataMap_t input; + L2FusionDataMap_t output; + uint32_t isUsed; +} TaskL2FusionInfo_t; + +using L2FusionInfoPtr = std::shared_ptr; + +typedef struct ToOpStruct { + int64_t opL1Space = 0; + std::vector opL1FusionType; + int64_t opL1WorkspaceFlag = 0; // for workspace flag + int64_t opL1WorkspaceSize = 0; + std::vector> validInputShape; + std::vector> validOutputShape; + std::vector> sliceInputOffset; // conv & pooling & ReadSelect + std::vector> sliceOutputOffset; // WriteSelect + std::vector totalShape; + uint32_t splitIndex = 0; + ToOpStruct() { + // set invalid value for essential variable + opL1Space = -1; + opL1WorkspaceSize = -1; + } +} ToOpStruct_t; + +enum 
OpImplType { + EN_IMPL_CUSTOM_CONSTANT_CCE = 0, // custom constant op + EN_IMPL_CUSTOM_TIK, // custom tik op + EN_IMPL_CUSTOM_TBE, // custom tbe op + EN_IMPL_HW_CONSTANT_CCE, // Huawei built-in constant op + EN_IMPL_HW_GENERAL_CCE, // Huawei built-in cce op + EN_IMPL_HW_TIK, // Huawei built-in tik op + EN_IMPL_HW_TBE, // Huawei built-in tbe op + EN_IMPL_RL, // RL op + EN_IMPL_PLUGIN_TBE, // Huawei built-in tbe plugin op + EN_IMPL_VECTOR_CORE_HW_TBE, // Huawei built-in tbe op + EN_IMPL_VECTOR_CORE_CUSTOM_TBE, // custom tbe op + EN_IMPL_NON_PERSISTENT_CUSTOM_TBE, // custom tbe op + EN_RESERVED // reserved value +}; + +static const std::map DATATYPE_SIZE_MAP{{ge::DT_FLOAT, sizeof(float)}, + {ge::DT_FLOAT16, sizeof(int16_t)}, + {ge::DT_INT8, sizeof(int8_t)}, + {ge::DT_INT32, sizeof(int32_t)}, + {ge::DT_UINT8, sizeof(uint8_t)}, + {ge::DT_UINT32, sizeof(uint32_t)}, + {ge::DT_INT16, sizeof(int16_t)}, + {ge::DT_UINT16, sizeof(uint16_t)}, + {ge::DT_INT64, sizeof(int64_t)}, + {ge::DT_UINT64, sizeof(uint64_t)}, + {ge::DT_DOUBLE, sizeof(double)}, + {ge::DT_BOOL, sizeof(bool)}, + {ge::DT_DUAL, sizeof(float) + sizeof(int8_t)}, + {ge::DT_DUAL_SUB_UINT8, sizeof(int8_t)}, + {ge::DT_DUAL_SUB_INT8, sizeof(int8_t)}}; +} // namespace fe +#endif diff --git a/inc/common/util/ai_core/common/graph_comm.h b/inc/common/util/ai_core/common/graph_comm.h new file mode 100644 index 00000000..d672e056 --- /dev/null +++ b/inc/common/util/ai_core/common/graph_comm.h @@ -0,0 +1,107 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_COMMON_UTILS_AI_CORE_COMMON_GRAPH_COMMON_H_ +#define INC_COMMON_UTILS_AI_CORE_COMMON_GRAPH_COMMON_H_ + +#include "graph/compute_graph.h" +#include "common/aicore_util_types.h" +#include "register/graph_optimizer/graph_optimize_register_error_codes.h" + +#include +#include +#include +#include + +namespace fe { + +using kScopeNodeMap_t = std::map>; +using kScopeNodePair_t = std::pair>; + +class GraphCommImpl; +using GraphCommImplPtr = std::unique_ptr; + +class GraphComm { + public: + GraphComm(const string &engineName); + virtual ~GraphComm(); + GraphComm(const GraphComm &in) = delete; + GraphComm &operator=(const GraphComm &in) = delete; + + Status GetscopeNodeMap(ge::ComputeGraph &graph, kScopeNodeMap_t &fusionMap); + + Status CopyFusionOpNodes(vector &fusInputEdgeList, vector &fusOutputEdgeList, + vector &fusNodelist, ge::OpDescPtr fusionOpDesc, + ge::ComputeGraphPtr fusionGraph); + + Status CopyFusionOpEdges(ge::OpDescPtr fusionOpDesc, ge::ComputeGraph &origGraph, ge::ComputeGraphPtr fusionGraph); + + Status GetNodeDataFlowMap(const ge::NodePtr &fusNode, + std::map> &fusionOpAnchorsMap, + ge::kFusionDataFlowVec_t &fusDataflowList, const int &mapType); + + Status GetFusionNodeEdgeList(std::vector &fusNodelist, std::vector &fusInputEdgeList, + std::vector &fusOutputEdgeList); + void ClearFusionSrc(); + + void ClearFusionDst(); + + void AddFusionOutputSrc(const uint32_t &src_op_id, const ge::AnchorPtr &src_anchor, const int32_t &fusion_src_index, + std::pair &node_dataindex_pair); + + void AddFusionInputSrc(const uint32_t &src_op_id, const ge::AnchorPtr &src_anchor, const int32_t &fusion_dst_index, + std::pair &node_dataindex_pair); + + void SaveFusionDst(const uint32_t &dst_op_id, ge::AnchorPtr dst_anchor); + + bool IsFusionDstExist(const uint32_t &dst_op_id, const ge::AnchorPtr &dst_anchor); + + bool GetFusionSrc(const uint32_t &src_op_id, 
const ge::AnchorPtr &src_anchor, int32_t &fusion_src_index, + int32_t &fusion_dst_index); + + Status GetFusionNodeCtrlEdgeList(vector &fusNodelist, vector &fusInputCtrlEdgeList, + vector &fusOutputCtrlEdgeList); + + Status MergeFusionNodeEdgeList(ge::NodePtr &fusNode, vector &fusNodelist, + vector &fusInputEdgeList, vector &fusOutputEdgeList); + + Status MergeFusionNodeCtrlEdgeList(ge::NodePtr &fusNode, vector &fusNodelist, + vector &fusInputEdgeList, + vector &fusOutputEdgeList); + + string GetEngineName(); + + private: + Status MergeFusionNodeInputEdgeList(ge::NodePtr fusNode, std::vector &fusNodelist, + std::vector &fusInputEdgeList); + Status MergeFusionNodeOutputEdgeList(ge::NodePtr fusNode, std::vector &fusNodelist, + std::vector &fusOutputEdgeList); + + string engineName_; + + std::vector exist_fusion_src_list_; + std::vector exist_fusion_dst_list_; + + // std::vector> + ge::kFusionDataFlowVec_t fusion_input_dataflow_list_; + + // std::vector> + ge::kFusionDataFlowVec_t fusion_output_dataflow_list_; + + GraphCommImplPtr graphCommImplPtr_; +}; +} // namespace fe +#endif diff --git a/inc/common/util/ai_core/common/scope_allocator.h b/inc/common/util/ai_core/common/scope_allocator.h new file mode 100644 index 00000000..3b264425 --- /dev/null +++ b/inc/common/util/ai_core/common/scope_allocator.h @@ -0,0 +1,42 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef INC_COMMON_UTILS_AI_CORE_COMMON_SCOPE_ALLOCATOR_H_ +#define INC_COMMON_UTILS_AI_CORE_COMMON_SCOPE_ALLOCATOR_H_ + +#include "graph/op_desc.h" + +namespace fe { +class ScopeAllocator { + public: + ScopeAllocator(); + virtual ~ScopeAllocator(); + ScopeAllocator(const ScopeAllocator& in) = delete; + ScopeAllocator& operator=(const ScopeAllocator& in) = delete; + + public: + void Init(); + int64_t GetCurrentScopeId(); + int64_t AllocateScopeId(void); + bool HasScopeAttr(ge::ConstOpDescPtr opdef); + bool GetScopeAttr(ge::ConstOpDescPtr opdef, int64_t& scopeId); + bool SetScopeAttr(ge::OpDescPtr opdef, int64_t scopeId); + + private: + int64_t scopeId; +}; +} // namespace fe +#endif diff --git a/src/ge/graph/passes/switch_split_pass.h b/inc/common/util/ai_core/param_calculate/aicore_param_calculator.h similarity index 63% rename from src/ge/graph/passes/switch_split_pass.h rename to inc/common/util/ai_core/param_calculate/aicore_param_calculator.h index 69ab01c3..c0c378fd 100644 --- a/src/ge/graph/passes/switch_split_pass.h +++ b/inc/common/util/ai_core/param_calculate/aicore_param_calculator.h @@ -14,15 +14,20 @@ * limitations under the License. 
*/ -#ifndef GE_GRAPH_PASSES_SWITCH_SPLIT_PASS_H_ -#define GE_GRAPH_PASSES_SWITCH_SPLIT_PASS_H_ +#ifndef AICORE_PARAM_CALCULATOR +#define AICORE_PARAM_CALCULATOR -#include -#include "graph/passes/base_pass.h" -namespace ge { -class SwitchSplitPass : public BaseNodePass { +#include "graph/node.h" +#include "graph_optimizer/graph_optimize_register_error_codes.h" + +namespace fe { +class AICoreParamCalculator { public: - Status Run(NodePtr &node) override; + AICoreParamCalculator(); + + ~AICoreParamCalculator(); + + Status CalcOpRunningParam(ge::Node &node); }; -} // namespace ge -#endif // GE_GRAPH_PASSES_SWITCH_SPLIT_PASS_H_ +} // namespace fe +#endif // AICORE_PARAM_CALCULATOR diff --git a/inc/common/util/ai_core/param_calculate/tensorsize_calculator.h b/inc/common/util/ai_core/param_calculate/tensorsize_calculator.h new file mode 100644 index 00000000..c82cca4b --- /dev/null +++ b/inc/common/util/ai_core/param_calculate/tensorsize_calculator.h @@ -0,0 +1,45 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef TENSORSIZE_CALCULATOR_H +#define TENSORSIZE_CALCULATOR_H + +#include "graph_optimizer/graph_optimize_register_error_codes.h" + +#include +#include +#include "graph/compute_graph.h" +#include "graph/op_desc.h" + +namespace fe { +class TensorSizeCalculator { + public: + /** + * Calculate the tensor size of input and output of each opdesc + * @param opDesc opdesc object + * @param opImplType op impl type + * @return status SUCCESS or FAILED + */ + static Status CalculateOpTensorSize(ge::OpDesc &opDesc); + + private: + static Status CalcInputOpTensorSize(ge::OpDesc &opDesc, int32_t &outputRealCalcFlag); + + static Status CalcOutputOpTensorSize(ge::OpDesc &opDesc, int32_t &outputRealCalcFlag); +}; +} // namespace fe + +#endif // TENSORSIZE_CALCULATOR_H diff --git a/inc/external/ge/ge_api.h b/inc/external/ge/ge_api.h index 08156539..b4b9bb2a 100644 --- a/inc/external/ge/ge_api.h +++ b/inc/external/ge/ge_api.h @@ -98,6 +98,15 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Session { /// Status RunGraphAsync(uint32_t graphId, const std::vector &inputs, RunAsyncCallback callback); + /// + /// @ingroup ge_graph + /// @brief get variables in the session with specific session id + /// @param [in] var_names: variable names + /// @param [out] var_values: variable values + /// @return Status result of function + /// + Status GetVariables(const std::vector &var_names, std::vector &var_values); + /// /// @ingroup ge_graph /// @brief register callback func with specific summary or checkpoint by users diff --git a/inc/external/ge/ge_ir_build.h b/inc/external/ge/ge_ir_build.h index 5982ae90..acf6991a 100644 --- a/inc/external/ge/ge_ir_build.h +++ b/inc/external/ge/ge_ir_build.h @@ -23,6 +23,12 @@ #include "graph/graph.h" #include "graph/ge_error_codes.h" +namespace { +#define IR_MAJOR_VERSION (int(1)) +#define IR_MINOR_VERSION (int(0)) +#define IR_PATCH_VERSION (int(0)) +} // namespace + namespace ge { struct ModelBufferData { @@ -71,5 +77,17 @@ graphStatus 
aclgrphBuildModel(const ge::Graph &graph, const std::map; using OperatorImplPtr = std::shared_ptr; @@ -65,8 +67,8 @@ using std::string; class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { public: friend class OperatorImpl; - friend class GraphBuilderImpl; + friend class NodeUtils; using OpInt = int64_t; using OpFloat = float; @@ -104,6 +106,8 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { Operator &SetInput(const string &dst_name, const Operator &src_oprt, const string &name); // lint !e148 + Operator &SetInput(const string &dst_name, const Operator &src_oprt, uint32_t index); + Operator &AddControlInput(const Operator &src_oprt); graphStatus GetInputConstData(const string &dst_name, Tensor &data) const; @@ -269,11 +273,15 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { OutHandler GetOutput(const string &name) const; + OutHandler GetOutput(uint32_t index) const; + OperatorImplPtr GetOperatorImplPtr() const; OperatorImplPtr operator_impl_{nullptr}; graphStatus GetInputConstDataOut(const string &dst_name, Tensor &data) const; + + std::shared_ptr GetNode() const; }; /*lint +e148*/ } // namespace ge diff --git a/inc/external/graph/operator_reg.h b/inc/external/graph/operator_reg.h index d155f4bd..f0e1e84a 100644 --- a/inc/external/graph/operator_reg.h +++ b/inc/external/graph/operator_reg.h @@ -130,6 +130,10 @@ class OpReg { Operator::SetInput(#x, v, srcName); \ return *this; \ } \ + _THIS_TYPE &set_input_##x(Operator &v, uint32_t index) { \ + Operator::SetInput(#x, v, index); \ + return *this; \ + } \ _THIS_TYPE &set_input_##x(Operator &v) { \ Operator::SetInput(#x, v); \ return *this; \ @@ -159,6 +163,10 @@ class OpReg { Operator::SetInput(#x, v, srcName); \ return *this; \ } \ + _THIS_TYPE &set_input_##x(Operator &v, uint32_t index) { \ + Operator::SetInput(#x, v, index); \ + return *this; \ + } \ TensorDesc get_input_desc_##x() const { return Operator::GetInputDesc(#x); } \ graphStatus update_input_desc_##x(const 
TensorDesc &tensorDesc) { \ return Operator::UpdateInputDesc(#x, tensorDesc); \ diff --git a/inc/external/register/scope/scope_fusion_pass_register.h b/inc/external/register/scope/scope_fusion_pass_register.h new file mode 100644 index 00000000..77be4b8c --- /dev/null +++ b/inc/external/register/scope/scope_fusion_pass_register.h @@ -0,0 +1,331 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_ +#define EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_ + +#include +#include +#include +#include +#include "ge/ge_api_error_codes.h" +#include "register/register_error_codes.h" +#include "register/register_types.h" +#include "graph/operator.h" + +#define CHECK_INNER_NODE_CONDITION(cond, fusion_rlt) \ + do { \ + if (!(cond)) { \ + if ((fusion_rlt) != nullptr) { \ + (fusion_rlt)->SetType(ge::kScopeInvalidType); \ + } \ + return; \ + } \ + } while (0) + +namespace domi { +class TensorFlowModelParser; +} // namespace domi +namespace ge { +const int32_t kFusionDisableIndex = 99999; +const char *const kScopeToMultiNodes = "ScopeToMultiNodes"; +const char *const kScopeInvalidType = "ScopeInvalidType"; +const char *const kInputFromFusionScope = "InputFromFusionScope"; +const char *const kOutputToFusionScope = "OutputToFusionScope"; +class ScopePattern; +using ScopeFusionPatterns = std::vector>; + +class ScopePassManager; + +class 
GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY Scope { + public: + explicit Scope(const std::string &name, const std::string &sub_type = "", Scope *father_scope = nullptr); + ~Scope(); + + std::string Name() const; + std::string SubType() const; + std::map AllNodesMap() const; + Scope *GetSubScope(const std::string &scope_name) const; + std::string LastName() const; + std::vector GetAllSubScopes() const; + const Scope *GetFatherScope() const; + + private: + class ScopeImpl; + std::unique_ptr impl_; + friend class ScopeBasePass; + friend class ScopeTree; + friend class NodeOpTypeFeature; + friend class NodeAttrFeature; + friend class ScopeFeature; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY FusionScopesResult { + public: + FusionScopesResult(); + ~FusionScopesResult(); + void SetName(const std::string &name); + void SetType(const std::string &type); + void SetDescription(const std::string &description); + std::string Name() const; + std::vector Nodes() const; + void InsertInputs(const std::string &inner_op_name, const std::vector &index_map); + void InsertOutputs(const std::string &inner_op_name, const std::vector &index_map); + + class InnerNodeInfo { + public: + explicit InnerNodeInfo(const std::string &fusion_node_name); + InnerNodeInfo(const std::string &fusion_node_name, const std::string &name, const std::string &type); + InnerNodeInfo(InnerNodeInfo &&other) noexcept; + InnerNodeInfo &operator=(InnerNodeInfo &&other) noexcept; + InnerNodeInfo(const InnerNodeInfo &) = delete; + InnerNodeInfo &operator=(const InnerNodeInfo &) = delete; + ~InnerNodeInfo(); + InnerNodeInfo &SetName(const std::string &name); + InnerNodeInfo &SetType(const std::string &type); + InnerNodeInfo &InsertInput(const std::string &input_node, int32_t peer_out_idx); + InnerNodeInfo &InsertOutput(const std::string &output_node, int32_t peer_in_idx); + ge::graphStatus BuildInnerNode(); + ge::graphStatus SetInputFormat(const std::string &input_name, const std::string &format); + 
ge::graphStatus SetOutputFormat(const std::string &output_name, const std::string &format); + ge::graphStatus SetDynamicInputFormat(const std::string &input_name, uint32_t index, const std::string &format); + ge::graphStatus SetDynamicOutputFormat(const std::string &output_name, uint32_t index, const std::string &format); + ge::Operator *MutableOperator(); + + std::string GetName() const; + std::string GetType() const; + std::vector> GetInputs() const; + std::vector> GetOutputs() const; + + private: + class InnerNodeInfoImpl; + std::unique_ptr impl_; + }; + + InnerNodeInfo *AddInnerNode(const std::string &name, const std::string &type); + InnerNodeInfo *MutableRecentInnerNode(); + InnerNodeInfo *MutableInnerNode(uint32_t index); + ge::graphStatus CheckInnerNodesInfo(); + + private: + class FusionScopesResultImpl; + std::unique_ptr impl_; + friend class ScopeGraph; + friend class ScopeBasePass; + friend class TensorFlowModelParser; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeTree { + public: + ScopeTree(); + Status Init(); + ScopeTree(const ScopeTree &scopetree) = delete; + ScopeTree &operator=(const ScopeTree &scopetree) = delete; + ~ScopeTree(); + + std::vector GetAllScopes() const; + + private: + class ScopeTreeImpl; + std::unique_ptr impl_; + friend class ScopeGraph; + friend class ScopeBasePass; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeGraph { + public: + ScopeGraph(); + Status Init(); + ScopeGraph(const ScopeGraph &scope_graph) = delete; + ScopeGraph &operator=(const ScopeGraph &scope_graph) = delete; + ~ScopeGraph(); + + const ScopeTree *GetScopeTree() const; + std::map GetNodesMap() const; + + private: + class ScopeGraphImpl; + std::unique_ptr impl_; + friend class ScopePassManager; + friend class ScopeBasePass; + friend class TensorFlowModelParser; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeAttrValue { + public: + ScopeAttrValue(); + ScopeAttrValue(ScopeAttrValue const &attr_value); + 
ScopeAttrValue &operator=(ScopeAttrValue const &attr_value); + ~ScopeAttrValue(); + + void SetIntValue(int64_t value); + void SetFloatValue(float value); + void SetStringValue(std::string value); + void SetBoolValue(bool value); + + private: + class ScopeAttrValueImpl; + std::unique_ptr impl_; + friend class NodeAttrFeature; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeBaseFeature { + public: + virtual bool Match(const Scope *scope) = 0; + virtual ~ScopeBaseFeature(){}; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY NodeOpTypeFeature : ScopeBaseFeature { + public: + NodeOpTypeFeature(std::string nodeType, int num, int step = 0); + NodeOpTypeFeature(NodeOpTypeFeature const &feature); + NodeOpTypeFeature &operator=(NodeOpTypeFeature const &feature); + ~NodeOpTypeFeature(); + bool Match(const Scope *scope) override; + + private: + class NodeOpTypeFeatureImpl; + std::unique_ptr impl_; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY NodeAttrFeature : ScopeBaseFeature { + public: + NodeAttrFeature(std::string nodeType, std::string attr_name, ge::DataType datatype, ScopeAttrValue attr_value); + NodeAttrFeature(NodeAttrFeature const &feature); + NodeAttrFeature &operator=(NodeAttrFeature const &feature); + ~NodeAttrFeature(); + bool Match(const Scope *scope) override; + + private: + class NodeAttrFeatureImpl; + std::unique_ptr impl_; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFeature : ScopeBaseFeature { + public: + ScopeFeature(std::string sub_type, int32_t num, std::string suffix = "", std::string sub_scope_mask = "", + int step = 0); + ScopeFeature(ScopeFeature const &feature); + ScopeFeature &operator=(ScopeFeature const &feature); + ~ScopeFeature(); + bool Match(const Scope *scope) override; + + private: + class ScopeFeatureImpl; + std::unique_ptr impl_; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopePattern { + public: + ScopePattern(); + ~ScopePattern(); + + ScopePattern 
&SetSubType(const std::string &sub_type); + ScopePattern &AddNodeOpTypeFeature(NodeOpTypeFeature feature); + ScopePattern &AddNodeAttrFeature(NodeAttrFeature feature); + ScopePattern &AddScopeFeature(ScopeFeature feature); + + private: + class ScopePatternImpl; + std::unique_ptr impl_; + friend class ScopeBasePass; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopesResult { + public: + ScopesResult(); + ScopesResult(ScopesResult const &result); + ScopesResult &operator=(ScopesResult const &result); + ~ScopesResult(); + + void SetScopes(std::vector &scopes); + void SetNodes(std::vector &nodes); + + private: + class ScopesResultImpl; + std::unique_ptr impl_; + friend class ScopeBasePass; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeBasePass { + public: + ScopeBasePass(); + virtual ~ScopeBasePass(); + + protected: + // Subclasses implement respective fusion strategies and build the Patterns + virtual std::vector DefinePatterns() = 0; + // Define the name of the scope pass + virtual std::string PassName() = 0; + // Subclasses implement respective multi-scope or operator fusion methods across scopes + virtual Status LastMatchScopesAndOPs(std::shared_ptr &scope_graph, + std::vector &results) = 0; + // Subclasses implement their own results and set the input and output of the final fusion operator + virtual void GenerateFusionResult(const std::vector &scopes, FusionScopesResult *fusion_rlt) = 0; + + private: + class ScopeBasePassImpl; + std::unique_ptr impl_; + friend class ge::ScopePassManager; + friend class ScopeBasePassImpl; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFusionPassRegistry { + public: + using CreateFn = ScopeBasePass *(*)(); + ~ScopeFusionPassRegistry(); + + static ScopeFusionPassRegistry &GetInstance() { + static ScopeFusionPassRegistry instance; + return instance; + } + + void RegisterScopeFusionPass(const std::string &pass_name, CreateFn create_fn, bool is_general); + + private: + 
ScopeFusionPassRegistry(); + class ScopeFusionPassRegistryImpl; + /*lint -e148*/ + std::unique_ptr impl_; + friend class TensorFlowModelParser; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeUtil { + public: + static std::string StringReplaceAll(std::string str, const std::string &old_value, const std::string &new_value); + static void FreeScopePatterns(ScopeFusionPatterns &patterns); + static void FreeOneBatchPattern(std::vector &one_batch_pattern); +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFusionPassRegistrar { + public: + ScopeFusionPassRegistrar(const char *pass_name, ScopeBasePass *(*create_fn)(), bool is_general); + ~ScopeFusionPassRegistrar() {} +}; + +#define REGISTER_SCOPE_FUSION_PASS(pass_name, scope_pass, is_general) \ + REGISTER_SCOPE_FUSION_PASS_UNIQ_HELPER(__COUNTER__, pass_name, scope_pass, is_general) + +#define REGISTER_SCOPE_FUSION_PASS_UNIQ_HELPER(ctr, pass_name, scope_pass, is_general) \ + REGISTER_SCOPE_FUSION_PASS_UNIQ(ctr, pass_name, scope_pass, is_general) + +#define REGISTER_SCOPE_FUSION_PASS_UNIQ(ctr, pass_name, scope_pass, is_general) \ + static ::ge::ScopeFusionPassRegistrar register_scope_fusion_pass##ctr __attribute__((unused)) = \ + ::ge::ScopeFusionPassRegistrar( \ + pass_name, []() -> ::ge::ScopeBasePass * { return new (std::nothrow) scope_pass(); }, is_general) +} // namespace ge + +#endif // EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_ diff --git a/inc/framework/common/ge_types.h b/inc/framework/common/ge_types.h index 00bfa301..3686befc 100644 --- a/inc/framework/common/ge_types.h +++ b/inc/framework/common/ge_types.h @@ -22,7 +22,7 @@ #include #include -#include "common/fmk_error_codes.h" +#include "framework/common/fmk_error_codes.h" #include "ge/ge_api_error_codes.h" #include "external/graph/types.h" #include "external/ge/ge_api_types.h" @@ -49,6 +49,7 @@ enum OpEngineType { }; const char *const GE_ENGINE_ATTR_MEM_TYPE_HBM = "HBM"; +const char *const GE_OPTION_EXEC_PLACEMENT = 
"ge.exec.placement"; // Data cache, including data address and length struct DataBuffer { @@ -128,6 +129,7 @@ struct OriginInputInfo { // The structure of AIPP info struct AippConfigInfo { + int8_t aipp_mode; int8_t input_format; int32_t src_image_size_w; int32_t src_image_size_h; @@ -175,6 +177,9 @@ struct AippConfigInfo { float var_reci_chn_1; float var_reci_chn_2; float var_reci_chn_3; + int8_t support_rotation; + uint32_t related_input_rank; + uint32_t max_src_image_size; }; // The structure of offline Modeldata @@ -250,5 +255,31 @@ struct ComputeGraphDescInfo { std::vector> output_shape; std::vector output_data_type; }; + +struct OpDescInfo { + std::string op_name; + uint32_t task_id; + uint32_t stream_id; + std::vector input_format; + std::vector> input_shape; + std::vector input_data_type; + std::vector input_addrs; + std::vector output_format; + std::vector> output_shape; + std::vector output_data_type; + std::vector output_addrs; +}; +struct ModelDumpConfig { + std::string model_name; + std::vector layers; +}; + +struct DumpConfig { + std::string dump_path; + std::string dump_mode; + std::string dump_status; + std::string dump_op_switch; + std::vector dump_list; +}; } // namespace ge #endif // INC_FRAMEWORK_COMMON_GE_TYPES_H_ diff --git a/inc/framework/common/types.h b/inc/framework/common/types.h index db692c36..189c63c3 100644 --- a/inc/framework/common/types.h +++ b/inc/framework/common/types.h @@ -606,6 +606,7 @@ static constexpr uint32_t MODEL_FILE_RESERVED_LENGTH = 79; /// @brief INPUT node type /// FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string INPUT_TYPE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMMY_DATA; /// /// @ingroup domi_omg diff --git a/inc/framework/common/util.h b/inc/framework/common/util.h index 952ce955..71c9367c 100644 --- a/inc/framework/common/util.h +++ b/inc/framework/common/util.h @@ -347,6 +347,14 @@ std::string ToString(const google::protobuf::RepeatedField 
&rpd_field) { /// uint64_t GetCurrentTimestap(); +/// +/// @ingroup domi_common +/// @brief Obtains the absolute time (timestamp) of the current system. +/// @return Timestamp, in seconds (US) +/// +/// +uint32_t GetCurrentSecondTimestap(); + /// /// @ingroup domi_common /// @brief Check whether the product of two int64 numbers exceeds the int64 range. diff --git a/inc/framework/engine/dnnengine.h b/inc/framework/engine/dnnengine.h index 142ac229..65897ac5 100644 --- a/inc/framework/engine/dnnengine.h +++ b/inc/framework/engine/dnnengine.h @@ -31,6 +31,7 @@ enum PriorityEnum { COST_1, COST_2, COST_9 = 9, + COST_10 = 10, }; struct DNNEngineAttribute { diff --git a/inc/framework/executor/ge_executor.h b/inc/framework/executor/ge_executor.h index 129b8613..f9fa4ce9 100644 --- a/inc/framework/executor/ge_executor.h +++ b/inc/framework/executor/ge_executor.h @@ -135,6 +135,15 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { /// ge::Status GetCombinedDynamicDims(uint32_t model_id, std::vector> &batch_info); + /// + /// @ingroup ge + /// @brief Get user designeate shape order + /// @param [in] model_id + /// @param [out] user_designate_shape_order + /// @return execute result + /// + ge::Status GetUserDesignateShapeOrder(uint32_t model_id, std::vector &user_designate_shape_order); + ge::Status GetCurShape(const uint32_t model_id, std::vector &batch_info, int32_t &dynamic_type); /// @@ -162,6 +171,8 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { ge::Status CommandHandle(const ge::Command &command); + ge::Status SetDump(const DumpConfig &dump_config); + /// /// @ingroup ge /// @brief Query model memory consuming interface @@ -261,6 +272,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { ge::Status GetOrigInputInfo(uint32_t model_id, uint32_t index, OriginInputInfo &orig_input_info); ge::Status GetAllAippInputOutputDims(uint32_t model_id, uint32_t index, std::vector &input_dims, std::vector &output_dims); + 
ge::Status GetOpDescInfo(uint32_t device_id, uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info); private: static bool isInit_; diff --git a/inc/framework/memory/memory_api.h b/inc/framework/memory/memory_api.h index 656e4710..ebb7e68c 100644 --- a/inc/framework/memory/memory_api.h +++ b/inc/framework/memory/memory_api.h @@ -27,6 +27,7 @@ namespace ge { enum MemStorageType { HBM = 0, RDMA_HBM, + HOST_DDR, }; struct HostVarInfo { diff --git a/inc/framework/omg/omg.h b/inc/framework/omg/omg.h index 6a120439..45a8896d 100644 --- a/inc/framework/omg/omg.h +++ b/inc/framework/omg/omg.h @@ -96,6 +96,10 @@ Status CheckCustomAiCpuOpLib(); Status DumpInfershapeJson(const ge::Graph &graph, const char *json_file); +Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const std::string &output_format); + +Status GetOutputLeaf(ge::NodePtr node, std::vector> &output_nodes_info); + void GetOutputNodesNameAndIndex(std::vector> &output_nodes_info, std::vector &output_nodes_name); diff --git a/inc/graph/debug/ge_attr_define.h b/inc/graph/debug/ge_attr_define.h index 57e389e8..714375e4 100644 --- a/inc/graph/debug/ge_attr_define.h +++ b/inc/graph/debug/ge_attr_define.h @@ -883,6 +883,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_ // Assign GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ASSIGN_VALIDATE_SHAPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ASSIGN_VAR_NAME; // ShapeN GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_N; @@ -939,6 +940,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BATCH_NUM; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BATCH_LABEL; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COMBINED_BATCH; 
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_USER_DESIGNEATE_SHAPE_ORDER; // Control flow GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_SWITCH_COND; @@ -957,7 +959,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM // Function Op GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARENT_NODE_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARENT_CONST_TYPE; // Used for mark the active node is for loop, type:bool GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_LOOP_ACTIVE; @@ -968,6 +969,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MEMORY_TYPE_WORKSPACE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MEMORY_TYPE_RANGE; + // Atomic addr clean attrs GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATOMIC_ATTR_INPUT_INDEX; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATOMIC_ATTR_OUTPUT_INDEX; diff --git a/inc/graph/ge_context.h b/inc/graph/ge_context.h index af6b35bc..53985e9c 100644 --- a/inc/graph/ge_context.h +++ b/inc/graph/ge_context.h @@ -24,6 +24,7 @@ namespace ge { class GEContext { public: graphStatus GetOption(const std::string &key, std::string &option); + bool GetHostExecFlag(); uint64_t SessionId(); uint32_t DeviceId(); uint64_t TraceId(); diff --git a/inc/graph/op_desc.h b/inc/graph/op_desc.h index 1457aa15..27c91efc 100644 --- a/inc/graph/op_desc.h +++ b/inc/graph/op_desc.h @@ -153,9 +153,6 @@ class OpDesc : public std::enable_shared_from_this, public AttrHolder { graphStatus AddDynamicOutputDesc(const string &name, const unsigned int num, bool isPushBack = true); - void RemoveInputDesc(uint32_t index); - void 
RemoveOutputDesc(uint32_t index); - bool IsOptionalInput(const string &name) const; bool IsOptionalInput(uint32_t index) const; diff --git a/inc/graph/utils/node_utils.h b/inc/graph/utils/node_utils.h index 019bb3a7..bf57148d 100644 --- a/inc/graph/utils/node_utils.h +++ b/inc/graph/utils/node_utils.h @@ -20,6 +20,7 @@ #include #include #include +#include "external/graph/operator.h" #include "graph/node.h" namespace ge { @@ -63,8 +64,11 @@ class NodeUtils { static void UnlinkAll(const Node &node); static graphStatus UpdatePeerNodeInputDesc(const NodePtr &node_ptr); - static graphStatus AppendInputAnchor(const NodePtr &node, uint32_t index); - static graphStatus RemoveInputAnchor(const NodePtr &node, uint32_t index); + static graphStatus AppendInputAnchor(const NodePtr &node, uint32_t num); + static graphStatus RemoveInputAnchor(const NodePtr &node, uint32_t num); + + static graphStatus AppendOutputAnchor(const NodePtr &node, uint32_t num); + static graphStatus RemoveOutputAnchor(const NodePtr &node, uint32_t num); static bool IsInNodesEmpty(const Node &node); static GeTensorDesc GetOutputDesc(const Node &node, uint32_t index); @@ -77,6 +81,7 @@ class NodeUtils { static graphStatus GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow); static std::string GetNodeType(const Node &node); + static std::string GetNodeType(const NodePtr &node); static ComputeGraphPtr GetSubgraph(const Node &node, uint32_t index); static graphStatus SetSubgraph(Node &node, uint32_t index, const ComputeGraphPtr &subgraph); @@ -100,8 +105,17 @@ class NodeUtils { /// @param [in] node /// @return Node /// + static NodePtr GetParentInput(const Node &node); static NodePtr GetParentInput(const NodePtr &node); + /// + /// @brief Get is dynamic shape graph from node. 
+ /// @param [in] node + /// @return bool + /// + static bool IsDynamicShape(const Node &node); + static bool IsDynamicShape(const NodePtr &node); + /// /// @brief Check is varying_input for while node /// @param [in] node: Data node for subgraph @@ -115,7 +129,7 @@ class NodeUtils { /// @param [out] string /// @return bool /// - static bool GetConstOpType(const NodePtr &in_node, std::string &op_type); + static bool GetConstOpType(const NodePtr &node, std::string &type); /// /// @brief Remove node-related subgraphs, including subgraphs of nodes in the subgraph. @@ -138,9 +152,15 @@ class NodeUtils { /// static vector GetSubgraphOutputNodes(const Node &node); - static NodePtr GetInDataNodeByIndex(const Node &node, int index); + static NodePtr GetInDataNodeByIndex(const Node &node, const int index); + + static vector> GetOutDataNodesWithAnchorByIndex(const Node &node, const int index); + + static ge::ConstNodePtr GetNodeFromOperator(const Operator &oprt); + + static graphStatus GetInputConstData(const ConstNodePtr &node_ptr, const string &dst_name, GeTensorPtr &ge_tensor); - static vector GetOutDataNodesByIndex(const Node &node, int index); + static graphStatus GetInputConstData(const Node &node, const string &dst_name, GeTensorPtr &ge_tensor); private: static std::map> map_send_info_; diff --git a/inc/graph/utils/type_utils.h b/inc/graph/utils/type_utils.h index aba2bdbf..38509b9a 100644 --- a/inc/graph/utils/type_utils.h +++ b/inc/graph/utils/type_utils.h @@ -34,6 +34,7 @@ class TypeUtils { static bool IsFormatValid(Format format); static bool IsInternalFormat(Format format); + static std::string ImplyTypeToSerialString(domi::ImplyType imply_type); static std::string DataTypeToSerialString(DataType data_type); static DataType SerialStringToDataType(const std::string &str); static std::string FormatToSerialString(Format format); diff --git a/src/common/graph/ge_attr_define.cc b/src/common/graph/ge_attr_define.cc index f78ca7aa..fde03a43 100644 --- 
a/src/common/graph/ge_attr_define.cc +++ b/src/common/graph/ge_attr_define.cc @@ -830,6 +830,7 @@ const std::string REF_VAR_PRE_PEER_OUT_INDEX = "ref_var_pre_peer_out_index"; // Assign const std::string ASSIGN_VALIDATE_SHAPE = "validate_shape"; +const std::string ASSIGN_VAR_NAME = "_assign_var_name"; // space2bacth batch2space const std::string BATCH_SPACE_ATTR_BLOCK = "block"; @@ -931,7 +932,6 @@ const std::string ATTR_NAME_NEXT_ITERATION = "_next_iteration_node"; // Function Op const std::string ATTR_NAME_PARENT_NODE_INDEX = "_parent_node_index"; -const std::string ATTR_NAME_PARENT_CONST_TYPE = "_parent_const_type"; // Used for mark the active node is for loop, type:bool const std::string ATTR_NAME_IS_LOOP_ACTIVE = "is_loop_active"; @@ -942,6 +942,8 @@ const std::string ATTR_NAME_MEMORY_TYPE_OUTPUT = "memory_type_output"; const std::string ATTR_NAME_MEMORY_TYPE_WORKSPACE = "memory_type_workspace"; +const std::string ATTR_NAME_MEMORY_TYPE_RANGE = "_memory_type_range"; + const std::string MODEL_ATTR_SESSION_ID = "session_id"; // lx fusion @@ -991,6 +993,8 @@ const std::string ATTR_MBATCH_ORIGIN_INPUT_DIMS = "_mbatch_origin_input_dims"; const std::string ATTR_DYNAMIC_TYPE = "mbatch_dynamic_type"; +const std::string ATTR_USER_DESIGNEATE_SHAPE_ORDER = "user_designate_shape_order"; + // For inserted op const std::string ATTR_INSERTED_BY_GE = "_inserted_by_ge"; diff --git a/src/common/graph/op_desc.cc b/src/common/graph/op_desc.cc index a7451641..0b22eb83 100644 --- a/src/common/graph/op_desc.cc +++ b/src/common/graph/op_desc.cc @@ -684,18 +684,6 @@ graphStatus OpDesc::AddDynamicOutputDesc(const string &name, const unsigned int return GRAPH_SUCCESS; } -void OpDesc::RemoveInputDesc(uint32_t index) { - while (inputs_desc_.size() > index) { - inputs_desc_.pop_back(); - } -} - -void OpDesc::RemoveOutputDesc(uint32_t index) { - while (outputs_desc_.size() > index) { - outputs_desc_.pop_back(); - } -} - bool OpDesc::IsOptionalInput(const string &name) const { return 
optional_input_names_.find(name) != optional_input_names_.end(); } diff --git a/src/common/graph/operator.cc b/src/common/graph/operator.cc index 03d4221e..1320e4c2 100644 --- a/src/common/graph/operator.cc +++ b/src/common/graph/operator.cc @@ -277,6 +277,22 @@ class OperatorImpl : public std::enable_shared_from_this { return output_ptr; } + OutHandler GetOutput(uint32_t index) { + GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return nullptr, "op_desc_ is nullptr."); + + string name = op_desc_->GetOutputNameByIndex(index); + if (name.empty()) { + GELOGE(GRAPH_FAILED, "Find src name by index failed. index[%u]", index); + return nullptr; + } + shared_ptr output_ptr = ComGraphMakeShared(name, index, shared_from_this()); + if (output_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "OpIO make shared failed"); + return nullptr; + } + return output_ptr; + } + GeTensorDesc GetOutputDesc(const string &name) const { GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GeTensorDesc(), "op_desc_ is nullptr."); @@ -540,6 +556,13 @@ Operator &Operator::SetInput(const std::string &dst_name, const ge::Operator &sr return *this; } +Operator &Operator::SetInput(const std::string &dst_name, const ge::Operator &src_oprt, uint32_t index) { + auto out_handler = src_oprt.GetOutput(index); + GE_CHK_BOOL_EXEC(out_handler != nullptr, return *this, "out_handler is nullptr."); + (void)SetInput(dst_name, out_handler); + return *this; +} + Operator &Operator::AddControlInput(const Operator &src_oprt) { if (operator_impl_ == nullptr) { GELOGE(GRAPH_FAILED, "operator impl is nullptr."); @@ -621,6 +644,11 @@ graphStatus Operator::GetInputConstDataOut(const string &dst_name, Tensor &data) return GRAPH_FAILED; } +std::shared_ptr Operator::GetNode() const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr."); + return operator_impl_->GetNode(); +} + TensorDesc Operator::GetInputDesc(const std::string &name) const { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), 
"operator impl is nullptr."); return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(name)); @@ -657,6 +685,11 @@ OutHandler Operator::GetOutput(const string &name) const { return operator_impl_->GetOutput(name); } +OutHandler Operator::GetOutput(uint32_t index) const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr."); + return operator_impl_->GetOutput(index); +} + TensorDesc Operator::GetOutputDesc(const std::string &name) const { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetOutputDesc(name)); @@ -1540,6 +1573,7 @@ void GraphUtils::BreakConnect(const std::map &all_node } op_impl->ClearOutputLinks(); op_impl->ClearInputLinks(); + OperatorKeeper::GetInstance().CheckOutOperator(op_impl); } } } // namespace ge diff --git a/src/common/graph/option/ge_context.cc b/src/common/graph/option/ge_context.cc index f5f5e4c9..421e0aff 100644 --- a/src/common/graph/option/ge_context.cc +++ b/src/common/graph/option/ge_context.cc @@ -17,12 +17,14 @@ #include "./ge_context.h" #include "./ge_global_options.h" #include "./ge_local_context.h" +#include "framework/common/ge_types.h" #include "framework/common/debug/ge_log.h" namespace ge { namespace { const int64_t kMinTrainingTraceJobId = 256; const int kDecimal = 10; +const char *kHostExecPlacement = "HOST"; } // namespace GEContext &GetContext() { static GEContext ge_context{}; @@ -33,6 +35,16 @@ graphStatus GEContext::GetOption(const std::string &key, std::string &option) { return GetThreadLocalContext().GetOption(key, option); } +bool GEContext::GetHostExecFlag() { + std::string exec_placement; + if (GetThreadLocalContext().GetOption(GE_OPTION_EXEC_PLACEMENT, exec_placement) != GRAPH_SUCCESS) { + GELOGW("get option OPTION_EXEC_PLACEMENT failed."); + return false; + } + GELOGD("Option ge.exec.placement is %s.", exec_placement.c_str()); + return 
exec_placement == kHostExecPlacement; +} + std::map &GetMutableGlobalOptions() { static std::map global_options{}; return global_options; diff --git a/src/common/graph/ref_relation.cc b/src/common/graph/ref_relation.cc index 7785bc43..9a9f66ba 100644 --- a/src/common/graph/ref_relation.cc +++ b/src/common/graph/ref_relation.cc @@ -243,8 +243,8 @@ graphStatus RefRelations::Impl::BuildRefRelationsForWhile( } auto in_data_anchor_idx = in_anchor->GetIdx(); auto net_in_desc = netoutput->GetOpDesc()->MutableInputDesc(static_cast(in_data_anchor_idx)); - int ref_d; - int ref_n; + int ref_d = 0; + int ref_n = 0; (void)AttrUtils::GetInt(peer_out_data_node->GetOpDesc(), kRefIndex, ref_d); (void)AttrUtils::GetInt(net_in_desc, kRefIndex, ref_n); diff --git a/src/common/graph/shape_refiner.cc b/src/common/graph/shape_refiner.cc index 479ec1cb..35c109af 100644 --- a/src/common/graph/shape_refiner.cc +++ b/src/common/graph/shape_refiner.cc @@ -351,6 +351,66 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { } return UpdateParentNodeForBranch(node, ref_out_tensors); } + +string Serial(const vector &dims) { + string serial_string; + serial_string += "["; + for (int64_t dim : dims) { + serial_string += std::to_string(dim) + " "; + } + serial_string += "]"; + return serial_string; +} + +graphStatus UpdateOpInputDesc(const ConstNodePtr &node_ptr) { + GE_IF_BOOL_EXEC(node_ptr == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED); + GE_IF_BOOL_EXEC(node_ptr->GetOpDesc() == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return GRAPH_FAILED); + for (const auto &in_anchor : node_ptr->GetAllInDataAnchors()) { + auto in_idx = in_anchor->GetIdx(); + auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor(); + if (peer_out_data_anchor == nullptr) { + continue; + } + auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode(); + if (peer_out_data_node == nullptr || peer_out_data_node->GetOpDesc() == nullptr) { + continue; + } + int peer_out_idx = 
peer_out_data_anchor->GetIdx(); + auto in_desc = node_ptr->GetOpDesc()->MutableInputDesc(static_cast(in_idx)); + auto peer_out_desc = peer_out_data_node->GetOpDesc()->MutableOutputDesc(static_cast(peer_out_idx)); + + // check shape and dtype continuity. do not stop process + auto in_shape = in_desc->GetShape().GetDims(); + auto in_dtype = in_desc->GetDataType(); + auto peer_out_shape = peer_out_desc->GetShape().GetDims(); + auto peer_out_dtype = peer_out_desc->GetDataType(); + if (peer_out_dtype != in_dtype) { + GELOGW( + "current node [%s] [%d]\'th out_dtype is [%s].peer output node [%s] [%d]\'th " + "output_dtype is [%s].The two dtype should be same! Please check graph and fix it", + node_ptr->GetName().c_str(), in_idx, TypeUtils::DataTypeToSerialString(in_dtype).c_str(), + peer_out_data_node->GetName().c_str(), peer_out_idx, TypeUtils::DataTypeToSerialString(peer_out_dtype).c_str()); + } else if ((!in_shape.empty()) && (in_shape != peer_out_shape)) { + string in_shape_str = Serial(in_shape); + string peer_out_shape_str = Serial(peer_out_shape); + GELOGW( + "current node [%s] [%d]\'th out_shape is [%s].peer input node [%s] [%d]\'th " + "input_shape is [%s].The two shape should be same! 
Please check graph and fix it", + node_ptr->GetName().c_str(), in_idx, in_shape_str.c_str(), peer_out_data_node->GetName().c_str(), peer_out_idx, + peer_out_shape_str.c_str()); + } + // refresh current node input desc + in_desc->SetOriginShape(peer_out_desc->GetOriginShape()); + in_desc->SetShape(peer_out_desc->GetShape()); + in_desc->SetDataType(peer_out_desc->GetDataType()); + in_desc->SetOriginDataType(peer_out_desc->GetOriginDataType()); + std::vector> shape_range; + (void)peer_out_desc->GetShapeRange(shape_range); + in_desc->SetShapeRange(shape_range); + ge::TensorUtils::SetRealDimCnt(*in_desc, static_cast(peer_out_desc->GetShape().GetDims().size())); + } + return GRAPH_SUCCESS; +} } // namespace void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase) { if (!IsLogEnable(GE, DLOG_DEBUG)) { @@ -427,9 +487,7 @@ graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator & return InferShapeAndType(node, op, true); } graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op, bool before_subgraph) { - GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED); auto op_desc = node->GetOpDesc(); - GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return GRAPH_FAILED); const auto &op_type = op_desc->GetType(); graphStatus ret; @@ -554,6 +612,19 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferSh GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node, bool before_subgraph) { GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED); + bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag(); + auto opdesc = node->GetOpDesc(); + GE_IF_BOOL_EXEC(opdesc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return GRAPH_FAILED); + // some op can not infershape twice such as aipp + bool 
need_update_input = !is_unknown_graph && !opdesc->HasAttr("has_infered_verified"); + if (need_update_input) { + auto status = UpdateOpInputDesc(node); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "update op input_desc failed!"); + return status; + } + } + if (node->Verify() != GRAPH_SUCCESS) { GELOGE(GRAPH_FAILED, "Verifying %s failed.", node->GetName().c_str()); return GRAPH_FAILED; @@ -561,7 +632,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferSh PrintInOutTensorShape(node, "before_infershape"); Operator op = OpDescUtils::CreateOperatorFromNode(node); - bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag(); if (!is_unknown_graph) { auto inference_context = CreateInferenceContext(context_map, node); if (inference_context == nullptr) { @@ -574,7 +644,21 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferSh graphStatus status = InferShapeAndType(node, op, before_subgraph); if (status == GRAPH_PARAM_INVALID || status == GRAPH_SUCCESS) { - (void)ge::NodeUtils::UpdatePeerNodeInputDesc(node); + if (is_unknown_graph) { + return GRAPH_SUCCESS; + } + auto op_desc = node->GetOpDesc(); + for (const auto &out_anchor : node->GetAllOutDataAnchors()) { + auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx()); + ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast(output_tensor->GetShape().GetDims().size())); + output_tensor->SetOriginShape(output_tensor->GetShape()); + output_tensor->SetOriginDataType(output_tensor->GetDataType()); + + GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s", + node->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(), + TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(), + TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str()); + } } else { GELOGE(GRAPH_FAILED, "%s call infer function failed.", node->GetName().c_str()); return 
GRAPH_FAILED; diff --git a/src/common/graph/utils/node_utils.cc b/src/common/graph/utils/node_utils.cc index 35a842e5..72981d10 100644 --- a/src/common/graph/utils/node_utils.cc +++ b/src/common/graph/utils/node_utils.cc @@ -15,6 +15,7 @@ */ #include "utils/node_utils.h" +#include "utils/op_desc_utils.h" #include "graph/utils/graph_utils.h" #include "debug/ge_op_types.h" #include "debug/ge_util.h" @@ -301,6 +302,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeer } for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) { auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx()); + auto out_dims = output_tensor->GetShape().GetDims(); + auto out_dtype = output_tensor->GetDataType(); ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast(output_tensor->GetShape().GetDims().size())); output_tensor->SetOriginShape(output_tensor->GetShape()); output_tensor->SetOriginDataType(output_tensor->GetDataType()); @@ -320,6 +323,35 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeer GELOGE(GRAPH_FAILED, "peer_input_desc is nullptr"); continue; } + // check shape and dtype continuity. do not stop process + auto peer_input_dims = peer_input_desc->GetShape().GetDims(); + auto peer_input_dtype = peer_input_desc->GetDataType(); + if (out_dtype != peer_input_dtype) { + GELOGW( + "current node [%s] [%d]\'th out_dtype is [%s].peer input node [%s] [%d]\'th " + "input_dtype is [%s].The two dtype should be same! 
Please check graph and fix it", + node_ptr->GetName().c_str(), out_anchor->GetIdx(), TypeUtils::DataTypeToSerialString(out_dtype).c_str(), + peer_anchor->GetOwnerNode()->GetName().c_str(), peer_anchor->GetIdx(), + TypeUtils::DataTypeToSerialString(peer_input_dtype).c_str()); + } else if ((!peer_input_dims.empty()) && (out_dims != peer_input_dims)) { + string out_shape_str, peer_in_shape_str; + out_shape_str += "["; + for (int64_t dim : out_dims) { + out_shape_str += std::to_string(dim) + " "; + } + out_shape_str += "]"; + peer_in_shape_str += "["; + for (int64_t dim : peer_input_dims) { + peer_in_shape_str += std::to_string(dim) + " "; + } + peer_in_shape_str += "]"; + + GELOGW( + "current node [%s] [%d]\'th out_shape is [%s].peer input node [%s] [%d]\'th " + "input_shape is [%s].The two shape should be same! Please check graph and fix it", + node_ptr->GetName().c_str(), out_anchor->GetIdx(), out_shape_str.c_str(), + peer_anchor->GetOwnerNode()->GetName().c_str(), peer_anchor->GetIdx(), peer_in_shape_str.c_str()); + } GELOGI("Peer input opdesc name is %s, need to flush: shape size is %zu, datatype is %d, original datatype is %d", peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor->GetShape().GetDimNum(), output_tensor->GetDataType(), output_tensor->GetOriginDataType()); @@ -341,15 +373,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeer } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AppendInputAnchor(const NodePtr &node, - uint32_t index) { + uint32_t num) { if (node == nullptr) { - GELOGE(GRAPH_FAILED, "Nodeptr is nullptr"); + GELOGE(GRAPH_FAILED, "Input node is null"); return GRAPH_FAILED; } GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT); - OpDescPtr op_desc = node->op_; - for (size_t i = op_desc->GetInputsSize(); i < index; ++i) { + const auto &op_desc = node->GetOpDesc(); + for (size_t i = op_desc->GetInputsSize(); i < num; ++i) { if (op_desc->AddInputDesc(data_desc) != 
GRAPH_SUCCESS) { GELOGE(GRAPH_FAILED, "Add input desc failed"); return GRAPH_FAILED; @@ -357,7 +389,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AppendInpu auto anchor = ComGraphMakeShared(node, i); if (anchor == nullptr) { - GELOGE(GRAPH_FAILED, "Current in_data_anchor is null, malloc shared_ptr failed."); + GELOGE(OUT_OF_MEMORY, "Current in data anchor is null, make shared_ptr failed."); return GRAPH_FAILED; } node->in_data_anchors_.push_back(anchor); @@ -367,22 +399,81 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AppendInpu } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::RemoveInputAnchor(const NodePtr &node, - uint32_t index) { + uint32_t num) { if (node == nullptr) { - GELOGE(GRAPH_FAILED, "Nodeptr is nullptr"); + GELOGE(GRAPH_FAILED, "Input node is null"); return GRAPH_FAILED; } - OpDescPtr op_desc = node->op_; - op_desc->RemoveInputDesc(index); + const auto &op_desc = node->GetOpDesc(); + while (op_desc->GetInputsSize() > num) { + if (!OpDescUtils::ClearInputDesc(op_desc, num)) { + return GRAPH_FAILED; + } + } - while (node->in_data_anchors_.size() > index) { + auto input_names = op_desc->GetAllInputName(); + (void)op_desc->UpdateInputName(input_names); + auto is_input_const = op_desc->GetIsInputConst(); + is_input_const.resize(num); + op_desc->SetIsInputConst(is_input_const); + + while (node->in_data_anchors_.size() > num) { node->in_data_anchors_.pop_back(); } return GRAPH_SUCCESS; } +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AppendOutputAnchor(const NodePtr &node, + uint32_t num) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "Input node is null"); + return GRAPH_FAILED; + } + + GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT); + const OpDescPtr &op_desc = node->GetOpDesc(); + for (size_t i = op_desc->GetOutputsSize(); i < num; ++i) { + if (op_desc->AddOutputDesc(data_desc) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Add output desc 
failed"); + return GRAPH_FAILED; + } + + auto anchor = ComGraphMakeShared(node, i); + if (anchor == nullptr) { + GELOGE(OUT_OF_MEMORY, "Current out data anchor is null, make shared_ptr failed."); + return GRAPH_FAILED; + } + node->out_data_anchors_.push_back(anchor); + } + + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::RemoveOutputAnchor(const NodePtr &node, + uint32_t num) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "Input node is null"); + return GRAPH_FAILED; + } + + const auto &op_desc = node->GetOpDesc(); + auto output_names = op_desc->GetAllOutputName(); + while (op_desc->GetOutputsSize() > num) { + if (!OpDescUtils::ClearOutputDesc(op_desc, num)) { + return GRAPH_FAILED; + } + } + (void)op_desc->UpdateOutputName(output_names); + + while (node->out_data_anchors_.size() > num) { + node->out_data_anchors_.pop_back(); + } + + return GRAPH_SUCCESS; +} + bool NodeUtils::IsInNodesEmpty(const Node &node) { for (const auto &in_anchor : node.in_data_anchors_) { if (in_anchor != nullptr) { @@ -488,11 +579,22 @@ std::string NodeUtils::GetNodeType(const Node &node) { if (node.GetType() != FRAMEWORKOP) { return node.GetType(); } + std::string type; (void)AttrUtils::GetStr(node.GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type); return type; } +std::string NodeUtils::GetNodeType(const NodePtr &node) { return node == nullptr ? 
"" : GetNodeType(*node); } + +graphStatus NodeUtils::GetInputConstData(const ConstNodePtr &node_ptr, const string &dst_name, GeTensorPtr &ge_tensor) { + return GRAPH_SUCCESS; +} + +graphStatus NodeUtils::GetInputConstData(const Node &node, const string &dst_name, GeTensorPtr &ge_tensor) { + return GRAPH_SUCCESS; +} + ComputeGraphPtr NodeUtils::GetSubgraph(const Node &node, uint32_t index) { auto op_desc = node.GetOpDesc(); if (op_desc == nullptr) { @@ -544,16 +646,17 @@ bool NodeUtils::IsSubgraphInput(const NodePtr &node) { if (parent_op_desc == nullptr) { return false; } - if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) { - bool is_unknown_shape = false; - (void)AttrUtils::GetBool(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape); - if (is_unknown_shape) return false; - } - if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE) && - kCaseOpTypes.count(parent_op_desc->GetType()) == 0 && kWhileOpTypes.count(parent_op_desc->GetType()) == 0 && - kForOpTypes.count(parent_op_desc->GetType()) == 0 && kIfOpTypes.count(parent_op_desc->GetType()) == 0) { - return false; + // dynamic shape unknown graph false + // dynamic shape known graph with functional subgraph maybe true + if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) { + if (node->GetOwnerComputeGraph()->GetParentGraph()->GetGraphUnknownFlag()) { + return false; + } else { + if (node->GetOwnerComputeGraph()->GetParentNode()->GetOwnerComputeGraph()->GetParentNode() == nullptr) { + return false; + } + } } return node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX); @@ -576,15 +679,13 @@ bool NodeUtils::IsSubgraphOutput(const NodePtr &node) { } if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) { - bool is_unknown_shape = false; - (void)AttrUtils::GetBool(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape); - if (is_unknown_shape) return false; - } - - if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE) && - 
kCaseOpTypes.count(parent_op_desc->GetType()) == 0 && kWhileOpTypes.count(parent_op_desc->GetType()) == 0 && - kForOpTypes.count(parent_op_desc->GetType()) == 0 && kIfOpTypes.count(parent_op_desc->GetType()) == 0) { - return false; + if (node->GetOwnerComputeGraph()->GetParentGraph()->GetGraphUnknownFlag()) { + return false; + } else { + if (node->GetOwnerComputeGraph()->GetParentNode()->GetOwnerComputeGraph()->GetParentNode() == nullptr) { + return false; + } + } } for (GeTensorDesc &tensor : node->GetOpDesc()->GetAllInputsDesc()) { @@ -601,16 +702,14 @@ bool NodeUtils::IsSubgraphOutput(const NodePtr &node) { /// @param [in] node /// @return Node /// -NodePtr NodeUtils::GetParentInput(const NodePtr &node) { - GE_CHECK_NOTNULL_EXEC(node, return nullptr); - +NodePtr NodeUtils::GetParentInput(const Node &node) { uint32_t parent_index = 0; - if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + if (!AttrUtils::GetInt(node.GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { return nullptr; } // Subgraph Data Node, check for constant input. - const ComputeGraphPtr &graph = node->GetOwnerComputeGraph(); + const ComputeGraphPtr &graph = node.GetOwnerComputeGraph(); GE_CHECK_NOTNULL_EXEC(graph, return nullptr); const NodePtr &parent_node = graph->GetParentNode(); @@ -625,6 +724,26 @@ NodePtr NodeUtils::GetParentInput(const NodePtr &node) { return peer_out_anchor->GetOwnerNode(); } +NodePtr NodeUtils::GetParentInput(const NodePtr &node) { return node == nullptr ? node : GetParentInput(*node); } + +/// +/// @brief Get is dynamic shape graph from node. 
+/// @param [in] node +/// @return bool +/// +bool NodeUtils::IsDynamicShape(const Node &node) { + const auto graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); + if (graph == nullptr) { + return false; + } + + bool is_dynamic_shape = false; + (void)AttrUtils::GetBool(graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, is_dynamic_shape); + return is_dynamic_shape; +} + +bool NodeUtils::IsDynamicShape(const NodePtr &node) { return node == nullptr ? false : IsDynamicShape(*node); } + /// /// @brief Check is varying_input for while node /// @param [in] node: Data node for subgraph @@ -678,27 +797,22 @@ bool NodeUtils::IsWhileVaryingInput(const ge::NodePtr &node) { /// @param [out] string /// @return bool /// -bool NodeUtils::GetConstOpType(const NodePtr &in_node, std::string &op_type) { - GE_CHECK_NOTNULL_EXEC(in_node, return false); +bool NodeUtils::GetConstOpType(const NodePtr &node, std::string &type) { + if (node == nullptr) { + return false; + } - if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) { - op_type = in_node->GetType(); + if ((node->GetType() == CONSTANT) || (node->GetType() == CONSTANTOP)) { + type = node->GetType(); return true; } - if (in_node->GetType() == DATA) { - std::string const_type; - if (!AttrUtils::GetStr(in_node->GetOpDesc(), ATTR_NAME_PARENT_CONST_TYPE, const_type)) { - return false; - } - - if ((const_type == CONSTANT) || (const_type == CONSTANTOP)) { - op_type = const_type; - return true; - } + if (node->GetType() != DATA) { + return false; // not subgraph input node } - return false; + const auto &parent = GetParentInput(node); + return GetConstOpType(parent, type); } /// @@ -809,7 +923,7 @@ vector NodeUtils::GetSubgraphOutputNodes(const Node &node) { return out_data_node_vec; } -NodePtr NodeUtils::GetInDataNodeByIndex(const Node &node, int index) { +NodePtr NodeUtils::GetInDataNodeByIndex(const Node &node, const int index) { if (node.GetInDataAnchor(index) == nullptr) { return nullptr; } @@ -819,12 
+933,13 @@ NodePtr NodeUtils::GetInDataNodeByIndex(const Node &node, int index) { return node.GetInDataAnchor(index)->GetPeerOutAnchor()->GetOwnerNode(); } -vector NodeUtils::GetOutDataNodesByIndex(const Node &node, int index) { - vector out_data_nodes; +vector> NodeUtils::GetOutDataNodesWithAnchorByIndex(const Node &node, const int index) { + vector> out_data_nodes; auto out_data_anchor = node.GetOutDataAnchor(index); if (out_data_anchor == nullptr) { return out_data_nodes; } + for (const auto peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { if (peer_in_anchor == nullptr) { continue; @@ -832,8 +947,10 @@ vector NodeUtils::GetOutDataNodesByIndex(const Node &node, int index) { if (peer_in_anchor->GetOwnerNode() == nullptr) { continue; } - out_data_nodes.emplace_back(peer_in_anchor->GetOwnerNode()); + out_data_nodes.emplace_back(std::make_pair(peer_in_anchor, peer_in_anchor->GetOwnerNode())); } return out_data_nodes; } + +ConstNodePtr NodeUtils::GetNodeFromOperator(const Operator &oprt) { return oprt.GetNode(); } } // namespace ge diff --git a/src/common/graph/utils/op_desc_utils.cc b/src/common/graph/utils/op_desc_utils.cc index 7a52a7f8..92883877 100644 --- a/src/common/graph/utils/op_desc_utils.cc +++ b/src/common/graph/utils/op_desc_utils.cc @@ -438,6 +438,11 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils:: if (switch_input.size() > 0) { ret.insert(ret.end(), switch_input.begin(), switch_input.end()); } + } else if (in_node->GetType() == DATA) { + auto parent = NodeUtils::GetParentInput(in_node); + if ((parent != nullptr) && (parent->GetType() == CONSTANT)) { + ret.push_back(parent); + } } } return ret; diff --git a/src/common/graph/utils/type_utils.cc b/src/common/graph/utils/type_utils.cc index 5215b141..2efc530e 100644 --- a/src/common/graph/utils/type_utils.cc +++ b/src/common/graph/utils/type_utils.cc @@ -244,6 +244,21 @@ static const std::map kFmkTypeToString = { {domi::ANDROID_NN, "android_nn"}, {domi::ONNX, "onnx"}, 
{domi::FRAMEWORK_RESERVED, "framework_reserved"}, }; +static const std::map kImplyTypeToString = { + {domi::ImplyType::BUILDIN, "buildin"}, {domi::ImplyType::TVM, "tvm"}, {domi::ImplyType::CUSTOM, "custom"}, + {domi::ImplyType::AI_CPU, "ai_cpu"}, {domi::ImplyType::CCE, "cce"}, {domi::ImplyType::GELOCAL, "gelocal"}, + {domi::ImplyType::HCCL, "hccl"}, {domi::ImplyType::INVALID, "invalid"}}; + +std::string TypeUtils::ImplyTypeToSerialString(domi::ImplyType imply_type) { + auto it = kImplyTypeToString.find(imply_type); + if (it != kImplyTypeToString.end()) { + return it->second; + } else { + GELOGE(GRAPH_FAILED, "ImplyTypeToSerialString: imply_type not support %u", imply_type); + return "UNDEFINED"; + } +} + bool TypeUtils::IsDataTypeValid(DataType dt) { uint32_t num = static_cast(dt); GE_CHK_BOOL_EXEC((num <= DT_UNDEFINED), return false, "The DataType is invalid"); diff --git a/src/ge/CMakeLists.txt b/src/ge/CMakeLists.txt index a527bc1f..922502e6 100755 --- a/src/ge/CMakeLists.txt +++ b/src/ge/CMakeLists.txt @@ -56,6 +56,9 @@ include_directories(${CMAKE_BINARY_DIR}/proto/ge) # need to remove dependencies on pb files later file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "client/ge_api.cc" + "common/dump/dump_manager.cc" + "common/dump/dump_properties.cc" + "common/dump/dump_op.cc" "common/formats/format_transfers/*.cc" "common/formats/formats.cc" "common/formats/utils/formats_trans_utils.cc" @@ -124,6 +127,7 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "graph/preprocess/insert_op/ge_aipp_op.cc" "graph/preprocess/insert_op/util_insert_aipp_op.cc" "graph/preprocess/multi_batch_copy_graph.cc" + "graph/preprocess/multi_batch_options.cc" "host_kernels/add_kernel.cc" "host_kernels/broadcast_args_kernel.cc" "host_kernels/broadcast_gradient_args_kernel.cc" @@ -138,6 +142,7 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "host_kernels/floormod_kernel.cc" "host_kernels/gather_v2_kernel.cc" "host_kernels/greater_kernel.cc" + 
"host_kernels/identity_kernel.cc" "host_kernels/kernel_utils.cc" "host_kernels/maximum_kernel.cc" "host_kernels/mul_kernel.cc" @@ -172,10 +177,18 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "hybrid/node_executor/aicpu/aicpu_node_executor.cc" "hybrid/node_executor/compiledsubgraph/known_node_executor.cc" "hybrid/node_executor/controlop/control_op_executor.cc" + "hybrid/node_executor/ge_local/ge_local_node_executor.cc" "hybrid/node_executor/hccl/hccl_node_executor.cc" "hybrid/node_executor/hostcpu/ge_local_node_executor.cc" + "hybrid/node_executor/host_cpu/host_cpu_node_executor.cc" + "hybrid/node_executor/host_cpu/kernel_factory.cc" + "hybrid/node_executor/host_cpu/kernel/no_op_kernel.cc" + "hybrid/node_executor/host_cpu/kernel/variable_kernel.cc" + "hybrid/node_executor/host_cpu/kernel/assign_kernel.cc" + "hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.cc" "hybrid/node_executor/node_executor.cc" "hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc" + "hybrid/node_executor/rts/rts_node_executor.cc" "hybrid/node_executor/task_context.cc" "init/gelib.cc" "model/ge_model.cc" @@ -215,6 +228,9 @@ target_link_libraries(ge_runner ######### libge_compiler.so ############# # need to remove dependencies on pb files later file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} + "common/dump/dump_properties.cc" + "common/dump/dump_manager.cc" + "common/dump/dump_op.cc" "common/formats/format_transfers/*.cc" "common/formats/formats.cc" "common/formats/utils/formats_trans_utils.cc" @@ -274,6 +290,7 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "graph/preprocess/insert_op/ge_aipp_op.cc" "graph/preprocess/insert_op/util_insert_aipp_op.cc" "graph/preprocess/multi_batch_copy_graph.cc" + "graph/preprocess/multi_batch_options.cc" "host_kernels/add_kernel.cc" "host_kernels/broadcast_args_kernel.cc" "host_kernels/broadcast_gradient_args_kernel.cc" @@ -288,6 +305,7 @@ file(GLOB INFER_SRC_LIST RELATIVE 
${CMAKE_CURRENT_LIST_DIR} "host_kernels/floormod_kernel.cc" "host_kernels/gather_v2_kernel.cc" "host_kernels/greater_kernel.cc" + "host_kernels/identity_kernel.cc" "host_kernels/kernel_utils.cc" "host_kernels/maximum_kernel.cc" "host_kernels/mul_kernel.cc" diff --git a/src/ge/client/ge_api.cc b/src/ge/client/ge_api.cc index 120c144a..9eb15ee4 100644 --- a/src/ge/client/ge_api.cc +++ b/src/ge/client/ge_api.cc @@ -390,6 +390,22 @@ Status Session::RunGraphAsync(uint32_t graph_id, const std::vector &var_names, std::vector &var_values) { + auto instance_ptr = ge::GELib::GetInstance(); + if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { + GELOGE(GE_CLI_GE_NOT_INITIALIZED, "SessionConstructor failed"); + return FAILED; + } + GELOGT(TRACE_RUNNING, "Get Variables"); + Status ret = ge::GELib::GetInstance()->SessionManagerObj().GetVariables(sessionId_, var_names, var_values); + if (ret != SUCCESS) { + GELOGE(ret, "SessionManager RunGraphAsync failed"); + return FAILED; + } + return SUCCESS; +} + bool Session::IsGraphNeedRebuild(uint32_t graph_id) { return ge::GELib::GetInstance()->SessionManagerObj().IsGraphNeedRebuild(sessionId_, graph_id); } diff --git a/src/ge/common/dump/dump_manager.cc b/src/ge/common/dump/dump_manager.cc new file mode 100644 index 00000000..d6783830 --- /dev/null +++ b/src/ge/common/dump/dump_manager.cc @@ -0,0 +1,120 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "common/dump/dump_manager.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/debug/log.h" + +namespace { +const char *const kDumpOFF = "OFF"; +const char *const kDumpoff = "off"; +const char *const kDumpOn = "on"; +} // namespace +namespace ge { +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY DumpManager &DumpManager::GetInstance() { + static DumpManager instance; + return instance; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status DumpManager::SetDumpConf(const DumpConfig &dump_config) { + std::lock_guard lock(mutex_); + dump_properties_.ClearDumpPropertyValue(); + dump_properties_.ClearDumpInfo(); + std::string dump_status; + std::string dump_path; + std::string dump_mode; + std::string dump_op_switch; + + if (dump_config.dump_status.empty()) { + GELOGI("Dump does not open"); + return SUCCESS; + } + + dump_status = dump_config.dump_status; + GELOGI("Dump status is %s", dump_status.c_str()); + if (dump_config.dump_status == kDumpoff || dump_config.dump_status == kDumpOFF) { + dump_properties_.ClearDumpPropertyValue(); + return SUCCESS; + } + dump_op_switch = dump_config.dump_op_switch; + if (dump_op_switch == kDumpoff && dump_config.dump_list.empty()) { + GELOGE(PARAM_INVALID, "Dump list is invalid,dump_op_switch is %s", dump_op_switch.c_str()); + return PARAM_INVALID; + } + + if (!dump_config.dump_list.empty()) { + for (auto model_dump : dump_config.dump_list) { + std::string model_name = model_dump.model_name; + GELOGI("Dump model is %s", model_name.c_str()); + std::set dump_layers; + for (auto layer : model_dump.layers) { + GELOGI("Dump layer is %s in model", layer.c_str()); + dump_layers.insert(layer); + } + dump_properties_.AddPropertyValue(model_name, dump_layers); + } + if (dump_op_switch == kDumpOn) { + GELOGI("Start to dump model and single op,dumo op switch is %s", dump_op_switch.c_str()); + } else { + GELOGI("Only dump model,dump op switch is %s", dump_op_switch.c_str()); + } + } else { + 
GELOGI("Only dump single op,dumo op switch is %s", dump_op_switch.c_str()); + } + + dump_path = dump_config.dump_path; + if (dump_path.empty()) { + GELOGE(PARAM_INVALID, "Dump path is empty"); + return PARAM_INVALID; + } + + if (dump_path[dump_path.size() - 1] != '/') { + dump_path = dump_path + "/"; + } + dump_path = dump_path + CurrentTimeInStr() + "/"; + GELOGI("Dump path is %s", dump_path.c_str()); + dump_properties_.SetDumpPath(dump_path); + + dump_mode = dump_config.dump_mode; + GELOGI("Dump mode is %s", dump_mode.c_str()); + dump_properties_.SetDumpMode(dump_mode); + + return SUCCESS; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool DumpManager::IsDumpOpen() { + std::lock_guard lock(mutex_); + if (!dump_properties_.GetDumpPath().empty()) { + return true; + } + return false; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const DumpProperties &DumpManager::GetDumpProperties() { + std::lock_guard lock(mutex_); + return dump_properties_; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpManager::SetModelName(const std::string &model_name) { + std::lock_guard lock(mutex_); + model_name_ = model_name; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string &DumpManager::GetModelName() { + std::lock_guard lock(mutex_); + return model_name_; +} +} // namespace ge diff --git a/src/ge/common/dump/dump_manager.h b/src/ge/common/dump/dump_manager.h new file mode 100644 index 00000000..ee38cef1 --- /dev/null +++ b/src/ge/common/dump/dump_manager.h @@ -0,0 +1,42 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_COMMON_DUMP_DUMP_MANAGER_H_ +#define GE_COMMON_DUMP_DUMP_MANAGER_H_ + +#include + +#include "common/dump/dump_properties.h" +#include "common/ge_types.h" + +namespace ge { +class DumpManager { + public: + static DumpManager &GetInstance(); + + Status SetDumpConf(const DumpConfig &dump_config); + bool IsDumpOpen(); + const DumpProperties &GetDumpProperties(); + void SetModelName(const std::string &model_name); + const std::string &GetModelName(); + + private: + DumpProperties dump_properties_; + std::mutex mutex_; + std::string model_name_; +}; +} // namespace ge +#endif // GE_COMMON_DUMP_DUMP_MANAGER_H_ diff --git a/src/ge/common/dump/dump_op.cc b/src/ge/common/dump/dump_op.cc new file mode 100644 index 00000000..a36204dd --- /dev/null +++ b/src/ge/common/dump/dump_op.cc @@ -0,0 +1,255 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "common/dump/dump_op.h" + +#include "aicpu/common/aicpu_task_struct.h" +#include "common/dump/dump_manager.h" +#include "common/ge/datatype_util.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/util.h" +#include "graph/anchor.h" +#include "graph/ge_tensor.h" +#include "graph/op_desc.h" +#include "graph/utils/tensor_utils.h" +#include "proto/ge_ir.pb.h" +#include "proto/op_mapping_info.pb.h" +#include "runtime/mem.h" + +namespace { +const uint32_t kAicpuLoadFlag = 1; +const char *const kDumpOutput = "output"; +const char *const kDumpInput = "input"; +const char *const kDumpAll = "all"; +const char *const kDumpKernelsDumpOp = "DumpDataInfo"; +} // namespace + +namespace ge { +DumpOp::~DumpOp() { + if (proto_dev_mem_ != nullptr) { + (void)rtFree(proto_dev_mem_); + } + if (proto_size_dev_mem_ != nullptr) { + (void)rtFree(proto_size_dev_mem_); + } + proto_dev_mem_ = nullptr; + proto_size_dev_mem_ = nullptr; +} + +void DumpOp::SetLoopAddr(void *global_step, void *loop_per_iter, void *loop_cond) { + global_step_ = reinterpret_cast(global_step); + loop_per_iter_ = reinterpret_cast(loop_per_iter); + loop_cond_ = reinterpret_cast(loop_cond); +} + +void DumpOp::SetDynamicModelInfo(const string &dynamic_model_name, uint32_t dynamic_model_id) { + dynamic_model_name_ = dynamic_model_name; + dynamic_model_id_ = dynamic_model_id; +} + +static void SetOpMappingLoopAddr(uintptr_t step_id, uintptr_t loop_per_iter, uintptr_t loop_cond, + aicpu::dump::OpMappingInfo &op_mapping_info) { + if (step_id != 0) { + GELOGI("step_id exists."); + op_mapping_info.set_step_id_addr(static_cast(step_id)); + } else { + GELOGI("step_id is null."); + } + + if (loop_per_iter != 0) { + GELOGI("loop_per_iter exists."); + op_mapping_info.set_iterations_per_loop_addr(static_cast(loop_per_iter)); + } else { + GELOGI("loop_per_iter is null."); + } + + if (loop_cond != 0) { + GELOGI("loop_cond exists."); + op_mapping_info.set_loop_cond_addr(static_cast(loop_cond)); + 
} else { + GELOGI("loop_cond is null."); + } +} + +Status DumpOp::DumpOutput(aicpu::dump::Task &task) { + GELOGI("Start dump output in Launch dump op"); + const auto &output_descs = op_desc_->GetAllOutputsDesc(); + for (size_t i = 0; i < output_descs.size(); ++i) { + aicpu::dump::Output output; + output.set_data_type(static_cast(DataTypeUtil::GetIrDataType(output_descs.at(i).GetDataType()))); + output.set_format(static_cast(output_descs.at(i).GetFormat())); + for (auto dim : output_descs.at(i).GetShape().GetDims()) { + output.mutable_shape()->add_dim(dim); + } + int64_t output_size = 0; + if (TensorUtils::GetTensorSizeInBytes(output_descs.at(i), output_size) != SUCCESS) { + GELOGE(PARAM_INVALID, "Get output size filed"); + return PARAM_INVALID; + } + GELOGD("Get output size in lanch dump op is %ld", output_size); + output.set_size(output_size); + output.set_address(static_cast(output_addrs_[i])); + task.mutable_output()->Add(std::move(output)); + } + return SUCCESS; +} + +Status DumpOp::DumpInput(aicpu::dump::Task &task) { + GELOGI("Start dump input in Launch dump op"); + const auto &input_descs = op_desc_->GetAllInputsDesc(); + for (size_t i = 0; i < input_descs.size(); ++i) { + aicpu::dump::Input input; + input.set_data_type(static_cast(DataTypeUtil::GetIrDataType(input_descs.at(i).GetDataType()))); + input.set_format(static_cast(input_descs.at(i).GetFormat())); + + for (auto dim : input_descs.at(i).GetShape().GetDims()) { + input.mutable_shape()->add_dim(dim); + } + int64_t input_size = 0; + if (TensorUtils::GetTensorSizeInBytes(input_descs.at(i), input_size) != SUCCESS) { + GELOGE(PARAM_INVALID, "Get output size filed"); + return PARAM_INVALID; + } + GELOGD("Get input size in lanch dump op is %ld", input_size); + input.set_size(input_size); + input.set_address(static_cast(input_addrs_[i])); + task.mutable_input()->Add(std::move(input)); + } + return SUCCESS; +} + +void DumpOp::SetDumpInfo(const DumpProperties &dump_properties, const OpDescPtr &op_desc, vector 
input_addrs, + vector output_addrs, rtStream_t stream) { + dump_properties_ = dump_properties; + op_desc_ = op_desc; + input_addrs_ = input_addrs; + output_addrs_ = output_addrs; + stream_ = stream; +} + +Status DumpOp::ExecutorDumpOp(aicpu::dump::OpMappingInfo &op_mapping_info) { + std::string proto_msg; + size_t proto_size = op_mapping_info.ByteSizeLong(); + bool ret = op_mapping_info.SerializeToString(&proto_msg); + if (!ret || proto_size == 0) { + GELOGE(FAILED, "Protobuf serialize failed,proto_size is %zu", proto_size); + return FAILED; + } + + rtError_t rt_ret = rtMalloc(&proto_dev_mem_, proto_size, RT_MEMORY_HBM); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rtMalloc failed, ret: 0x%X", rt_ret); + return RT_FAILED; + } + + rt_ret = rtMemcpy(proto_dev_mem_, proto_size, proto_msg.c_str(), proto_size, RT_MEMCPY_HOST_TO_DEVICE); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rtMemcpy failed, ret: 0x%X", rt_ret); + return RT_FAILED; + } + + rt_ret = rtMalloc(&proto_size_dev_mem_, sizeof(size_t), RT_MEMORY_HBM); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rtMalloc failed, ret: 0x%X", rt_ret); + return RT_FAILED; + } + rt_ret = rtMemcpy(proto_size_dev_mem_, sizeof(size_t), &proto_size, sizeof(size_t), RT_MEMCPY_HOST_TO_DEVICE); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rtMemcpy failed, ret: 0x%X", rt_ret); + return RT_FAILED; + } + + constexpr int32_t ioAddrNum = 2; + constexpr uint32_t argsSize = sizeof(aicpu::AicpuParamHead) + ioAddrNum * sizeof(uint64_t); + char args[argsSize] = {0}; + auto paramHead = reinterpret_cast(args); + paramHead->length = argsSize; + paramHead->ioAddrNum = ioAddrNum; + auto ioAddr = reinterpret_cast(args + sizeof(aicpu::AicpuParamHead)); + ioAddr[0] = reinterpret_cast(proto_dev_mem_); + ioAddr[1] = reinterpret_cast(proto_size_dev_mem_); + rt_ret = rtCpuKernelLaunch(nullptr, kDumpKernelsDumpOp, + 1, // blockDim default 1 + args, argsSize, + nullptr, // no need smDesc + stream_); + 
if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rtCpuKernelLaunch failed,rt_ret:0x%X", rt_ret); + return rt_ret; + } + GELOGI("Kernel launch dump op success"); + return SUCCESS; +} + +Status DumpOp::LaunchDumpOp() { + GELOGI("Start to launch dump op %s", op_desc_->GetName().c_str()); + int32_t device_id = 0; + rtError_t rt_ret = rtGetDevice(&device_id); + if (rt_ret != RT_ERROR_NONE || device_id < 0) { + GELOGE(RT_FAILED, "Call rtGetDevice failed, ret = 0x%X, device_id = %d.", rt_ret, device_id); + return RT_FAILED; + } + aicpu::dump::OpMappingInfo op_mapping_info; + auto dump_path = dump_properties_.GetDumpPath() + std::to_string(device_id) + "/"; + op_mapping_info.set_dump_path(dump_path); + op_mapping_info.set_flag(kAicpuLoadFlag); + op_mapping_info.set_dump_step(dump_properties_.GetDumpStep()); + if (!dynamic_model_name_.empty()) { + op_mapping_info.set_model_name(dynamic_model_name_); + op_mapping_info.set_model_id(dynamic_model_id_); + } + SetOpMappingLoopAddr(global_step_, loop_per_iter_, loop_cond_, op_mapping_info); + GELOGI("Dump step is %s ,dump path is %s ,in Launch dump op", dump_properties_.GetDumpStep().c_str(), + dump_path.c_str()); + + aicpu::dump::Task task; + task.mutable_op()->set_op_name(op_desc_->GetName()); + task.mutable_op()->set_op_type(op_desc_->GetType()); + if (dump_properties_.GetDumpMode() == kDumpOutput) { + if (DumpOutput(task) != SUCCESS) { + GELOGE(FAILED, "Dump output failed"); + return FAILED; + } + op_mapping_info.mutable_task()->Add(std::move(task)); + } + if (dump_properties_.GetDumpMode() == kDumpInput) { + if (DumpInput(task) != SUCCESS) { + GELOGE(FAILED, "Dump input failed"); + return FAILED; + } + op_mapping_info.mutable_task()->Add(std::move(task)); + } + if (dump_properties_.GetDumpMode() == kDumpAll) { + auto ret = DumpOutput(task); + if (ret != SUCCESS) { + GELOGE(FAILED, "Dump output failed when in dumping all"); + return FAILED; + } + ret = DumpInput(task); + if (ret != SUCCESS) { + GELOGE(FAILED, "Dump 
input failed when in dumping all"); + return FAILED; + } + op_mapping_info.mutable_task()->Add(std::move(task)); + } + auto ret = ExecutorDumpOp(op_mapping_info); + if (ret != SUCCESS) { + GELOGE(ret, "Executor dump op failed"); + return ret; + } + return SUCCESS; +} +} // namespace ge diff --git a/src/ge/common/dump/dump_op.h b/src/ge/common/dump/dump_op.h new file mode 100644 index 00000000..b3042245 --- /dev/null +++ b/src/ge/common/dump/dump_op.h @@ -0,0 +1,61 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_COMMON_DUMP_DUMP_OP_H_ +#define GE_COMMON_DUMP_DUMP_OP_H_ + +#include + +#include "common/ge_inner_error_codes.h" +#include "common/properties_manager.h" +#include "proto/op_mapping_info.pb.h" +#include "runtime/stream.h" + +namespace ge { +class DumpOp { + public: + DumpOp() = default; + ~DumpOp(); + + void SetDumpInfo(const DumpProperties &dump_properties, const OpDescPtr &op_desc, vector input_addrs, + vector output_addrs, rtStream_t stream); + Status LaunchDumpOp(); + void SetLoopAddr(void *global_step, void *loop_per_iter, void *loop_cond); + void SetDynamicModelInfo(const string &dynamic_model_name, uint32_t dynamic_model_id); + + private: + Status ExecutorDumpOp(aicpu::dump::OpMappingInfo &op_mapping_info); + Status DumpOutput(aicpu::dump::Task &task); + Status DumpInput(aicpu::dump::Task &task); + + DumpProperties dump_properties_; + OpDescPtr op_desc_; + std::vector input_addrs_; + std::vector output_addrs_; + + void *proto_dev_mem_ = nullptr; + void *proto_size_dev_mem_ = nullptr; + rtStream_t stream_; + uintptr_t global_step_; + uintptr_t loop_per_iter_; + uintptr_t loop_cond_; + + std::string dynamic_model_name_; + std::uint32_t dynamic_model_id_; +}; +} // namespace ge + +#endif // GE_COMMON_DUMP_DUMP_OP_H_ diff --git a/src/ge/common/dump/dump_properties.cc b/src/ge/common/dump/dump_properties.cc new file mode 100644 index 00000000..cbf3697d --- /dev/null +++ b/src/ge/common/dump/dump_properties.cc @@ -0,0 +1,238 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/dump/dump_properties.h" + +#include +#include + +#include "common/ge/ge_util.h" +#include "common/util.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/debug/log.h" +#include "framework/common/ge_types.h" +#include "framework/common/types.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/ge_context.h" +#include "graph/utils/attr_utils.h" + +namespace { +const std::string kEnableFlag = "1"; + +const uint32_t kAicoreOverflow = (0x1 << 0); +const uint32_t kAtomicOverflow = (0x1 << 1); +const uint32_t kAllOverflow = (kAicoreOverflow | kAtomicOverflow); +} // namespace +namespace ge { +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY DumpProperties::DumpProperties(const DumpProperties &other) { + CopyFrom(other); +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY DumpProperties &DumpProperties::operator=( + const DumpProperties &other) { + CopyFrom(other); + return *this; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::InitByOptions() { + enable_dump_.clear(); + enable_dump_debug_.clear(); + dump_path_.clear(); + dump_step_.clear(); + dump_mode_.clear(); + is_op_debug_ = false; + op_debug_mode_ = 0; + + std::string enable_dump; + (void)GetContext().GetOption(OPTION_EXEC_ENABLE_DUMP, enable_dump); + enable_dump_ = enable_dump; + + std::string enable_dump_debug; + (void)GetContext().GetOption(OPTION_EXEC_ENABLE_DUMP_DEBUG, enable_dump_debug); + enable_dump_debug_ = enable_dump_debug; + + if ((enable_dump_ == kEnableFlag) || (enable_dump_debug_ == kEnableFlag)) { + std::string dump_path; + if (GetContext().GetOption(OPTION_EXEC_DUMP_PATH, dump_path) == GRAPH_SUCCESS) { + if (!dump_path.empty() && dump_path[dump_path.size() - 1] != '/') { + dump_path = dump_path + "/"; + } + dump_path = dump_path + CurrentTimeInStr() + "/"; + GELOGI("Get dump path %s successfully", 
dump_path.c_str()); + SetDumpPath(dump_path); + } else { + GELOGW("Dump path is not set"); + } + } + + if (enable_dump_ == kEnableFlag) { + std::string dump_step; + if (GetContext().GetOption(OPTION_EXEC_DUMP_STEP, dump_step) == GRAPH_SUCCESS) { + GELOGD("Get dump step %s successfully", dump_step.c_str()); + SetDumpStep(dump_step); + } + string dump_mode; + if (GetContext().GetOption(OPTION_EXEC_DUMP_MODE, dump_mode) == GRAPH_SUCCESS) { + GELOGD("Get dump mode %s successfully", dump_mode.c_str()); + SetDumpMode(dump_mode); + } + AddPropertyValue(DUMP_ALL_MODEL, {}); + } + + SetDumpDebugOptions(); +} + +// The following is the new dump scenario of the fusion operator +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::AddPropertyValue( + const std::string &model, const std::set &layers) { + for (const std::string &layer : layers) { + GELOGI("This model %s config to dump layer %s", model.c_str(), layer.c_str()); + } + + model_dump_properties_map_[model] = layers; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::DeletePropertyValue(const std::string &model) { + auto iter = model_dump_properties_map_.find(model); + if (iter != model_dump_properties_map_.end()) { + model_dump_properties_map_.erase(iter); + } +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::ClearDumpPropertyValue() { + model_dump_properties_map_.clear(); +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::ClearDumpInfo() { + enable_dump_.clear(); + enable_dump_debug_.clear(); + dump_path_.clear(); + dump_step_.clear(); + dump_mode_.clear(); + is_op_debug_ = false; + op_debug_mode_ = 0; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::set DumpProperties::GetAllDumpModel() const { + std::set model_list; + for (auto &iter : model_dump_properties_map_) { + model_list.insert(iter.first); + } + + return model_list; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::set 
DumpProperties::GetPropertyValue( + const std::string &model) const { + auto iter = model_dump_properties_map_.find(model); + if (iter != model_dump_properties_map_.end()) { + return iter->second; + } + return {}; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool DumpProperties::IsLayerNeedDump( + const std::string &model, const std::string &om_name, const std::string &op_name) const { + // if dump all + if (model_dump_properties_map_.find(DUMP_ALL_MODEL) != model_dump_properties_map_.end()) { + return true; + } + + // if this model need dump + auto om_name_iter = model_dump_properties_map_.find(om_name); + auto model_name_iter = model_dump_properties_map_.find(model); + if (om_name_iter != model_dump_properties_map_.end() || model_name_iter != model_dump_properties_map_.end()) { + // if no dump layer info, dump all layer in this model + auto model_iter = om_name_iter != model_dump_properties_map_.end() ? om_name_iter : model_name_iter; + if (model_iter->second.empty()) { + return true; + } + + return model_iter->second.find(op_name) != model_iter->second.end(); + } + + GELOGD("Model %s is not seated to be dump.", model.c_str()); + return false; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpPath(const std::string &path) { + dump_path_ = path; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string &DumpProperties::GetDumpPath() const { + return dump_path_; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpStep(const std::string &step) { + dump_step_ = step; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string &DumpProperties::GetDumpStep() const { + return dump_step_; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpMode(const std::string &mode) { + dump_mode_ = mode; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string &DumpProperties::GetDumpMode() const { + return dump_mode_; +} + +void 
DumpProperties::CopyFrom(const DumpProperties &other) { + if (&other != this) { + enable_dump_ = other.enable_dump_; + enable_dump_debug_ = other.enable_dump_debug_; + dump_path_ = other.dump_path_; + dump_step_ = other.dump_step_; + dump_mode_ = other.dump_mode_; + + model_dump_properties_map_ = other.model_dump_properties_map_; + is_op_debug_ = other.is_op_debug_; + op_debug_mode_ = other.op_debug_mode_; + } +} + +void DumpProperties::SetDumpDebugOptions() { + if (enable_dump_debug_ == kEnableFlag) { + std::string dump_debug_mode; + if (GetContext().GetOption(OPTION_EXEC_DUMP_DEBUG_MODE, dump_debug_mode) == GRAPH_SUCCESS) { + GELOGD("Get dump debug mode %s successfully", dump_debug_mode.c_str()); + } else { + GELOGW("Dump debug mode is not set."); + return; + } + + if (dump_debug_mode == OP_DEBUG_AICORE) { + GELOGD("ge.exec.dumpDebugMode=aicore_overflow, op debug is open."); + is_op_debug_ = true; + op_debug_mode_ = kAicoreOverflow; + } else if (dump_debug_mode == OP_DEBUG_ATOMIC) { + GELOGD("ge.exec.dumpDebugMode=atomic_overflow, op debug is open."); + is_op_debug_ = true; + op_debug_mode_ = kAtomicOverflow; + } else if (dump_debug_mode == OP_DEBUG_ALL) { + GELOGD("ge.exec.dumpDebugMode=all, op debug is open."); + is_op_debug_ = true; + op_debug_mode_ = kAllOverflow; + } else { + GELOGW("ge.exec.dumpDebugMode is invalid."); + } + } else { + GELOGI("ge.exec.enableDumpDebug is false or is not set."); + } +} +} // namespace ge diff --git a/src/ge/common/dump/dump_properties.h b/src/ge/common/dump/dump_properties.h new file mode 100644 index 00000000..a397cac4 --- /dev/null +++ b/src/ge/common/dump/dump_properties.h @@ -0,0 +1,86 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_COMMON_DUMP_DUMP_PROPERTIES_H_ +#define GE_COMMON_DUMP_DUMP_PROPERTIES_H_ + +#include +#include +#include +#include + +namespace ge { +class DumpProperties { + public: + DumpProperties() = default; + + ~DumpProperties() = default; + + DumpProperties(const DumpProperties &dump); + + DumpProperties &operator=(const DumpProperties &dump); + + void InitByOptions(); + + void AddPropertyValue(const std::string &model, const std::set &layers); + + void DeletePropertyValue(const std::string &model); + + void ClearDumpPropertyValue(); + + void ClearDumpInfo(); + + std::set GetAllDumpModel() const; + + std::set GetPropertyValue(const std::string &model) const; + + bool IsLayerNeedDump(const std::string &model, const std::string &om_name, const std::string &op_name) const; + + void SetDumpPath(const std::string &path); + + const std::string &GetDumpPath() const; + + void SetDumpStep(const std::string &step); + + const std::string &GetDumpStep() const; + + void SetDumpMode(const std::string &mode); + + const std::string &GetDumpMode() const; + + bool IsOpDebugOpen() const { return is_op_debug_; } + + uint32_t GetOpDebugMode() const { return op_debug_mode_; } + + private: + void CopyFrom(const DumpProperties &other); + + void SetDumpDebugOptions(); + + std::string enable_dump_; + std::string enable_dump_debug_; + + std::string dump_path_; + std::string dump_step_; + std::string dump_mode_; + std::map> model_dump_properties_map_; + + bool is_op_debug_ = false; + uint32_t op_debug_mode_ = 0; +}; +} // namespace ge + +#endif // 
GE_COMMON_DUMP_DUMP_PROPERTIES_H_ \ No newline at end of file diff --git a/src/ge/common/ge/datatype_util.cc b/src/ge/common/ge/datatype_util.cc index 14635c14..f2ff12cb 100644 --- a/src/ge/common/ge/datatype_util.cc +++ b/src/ge/common/ge/datatype_util.cc @@ -15,23 +15,54 @@ */ #include "common/ge/datatype_util.h" +#include "proto/ge_ir.pb.h" #include namespace { const std::vector kEmptyDatatypeVector; std::map> g_translatable_data_type = { - // key:src datatype, value:dst datatype - {ge::DT_FLOAT, {ge::DT_FLOAT16, ge::DT_FLOAT}}, - {ge::DT_BOOL, {ge::DT_INT32}}, - {ge::DT_FLOAT16, {ge::DT_FLOAT, ge::DT_FLOAT16}}, - {ge::DT_INT64, {ge::DT_INT32}}}; + // key:src datatype, value:dst datatype + {ge::DT_FLOAT, {ge::DT_FLOAT16, ge::DT_FLOAT}}, + {ge::DT_BOOL, {ge::DT_INT32}}, + {ge::DT_FLOAT16, {ge::DT_FLOAT, ge::DT_FLOAT16}}, + {ge::DT_INT64, {ge::DT_INT32}}}; std::map> g_reverse_translatable_data_type = { - // key:dst datatype,value:src datatype - {ge::DT_FLOAT16, {ge::DT_FLOAT, ge::DT_FLOAT16}}, - {ge::DT_INT32, {ge::DT_BOOL, ge::DT_INT64}}, - {ge::DT_FLOAT, {ge::DT_FLOAT16, ge::DT_FLOAT}}}; + // key:dst datatype,value:src datatype + {ge::DT_FLOAT16, {ge::DT_FLOAT, ge::DT_FLOAT16}}, + {ge::DT_INT32, {ge::DT_BOOL, ge::DT_INT64}}, + {ge::DT_FLOAT, {ge::DT_FLOAT16, ge::DT_FLOAT}}}; + +static const std::map g_dump_data_type_map = { + // key:ge datatype,value:proto datatype + {ge::DT_UNDEFINED, ge::proto::DT_UNDEFINED}, + {ge::DT_FLOAT, ge::proto::DT_FLOAT}, + {ge::DT_FLOAT16, ge::proto::DT_FLOAT16}, + {ge::DT_INT8, ge::proto::DT_INT8}, + {ge::DT_UINT8, ge::proto::DT_UINT8}, + {ge::DT_INT16, ge::proto::DT_INT16}, + {ge::DT_UINT16, ge::proto::DT_UINT16}, + {ge::DT_INT32, ge::proto::DT_INT32}, + {ge::DT_INT64, ge::proto::DT_INT64}, + {ge::DT_UINT32, ge::proto::DT_UINT32}, + {ge::DT_UINT64, ge::proto::DT_UINT64}, + {ge::DT_BOOL, ge::proto::DT_BOOL}, + {ge::DT_DOUBLE, ge::proto::DT_DOUBLE}, + {ge::DT_DUAL, ge::proto::DT_DUAL}, + {ge::DT_DUAL_SUB_INT8, 
ge::proto::DT_DUAL_SUB_INT8}, + {ge::DT_DUAL_SUB_UINT8, ge::proto::DT_DUAL_SUB_UINT8}, + {ge::DT_COMPLEX64, ge::proto::DT_COMPLEX64}, + {ge::DT_COMPLEX128, ge::proto::DT_COMPLEX128}, + {ge::DT_QINT8, ge::proto::DT_QINT8}, + {ge::DT_QINT16, ge::proto::DT_QINT16}, + {ge::DT_QINT32, ge::proto::DT_QINT32}, + {ge::DT_QUINT8, ge::proto::DT_QUINT8}, + {ge::DT_QUINT16, ge::proto::DT_QUINT16}, + {ge::DT_RESOURCE, ge::proto::DT_RESOURCE}, + {ge::DT_STRING_REF, ge::proto::DT_STRING_REF}, + {ge::DT_STRING, ge::proto::DT_STRING}, +}; } // namespace namespace ge { @@ -67,4 +98,13 @@ const std::vector &DataTypeUtil::GetTranslatableDataTypesByDst(con return search->second; } + +int32_t DataTypeUtil::GetIrDataType(ge::DataType data_type) { + auto iter = g_dump_data_type_map.find(data_type); + if (iter == g_dump_data_type_map.end()) { + return static_cast(ge::proto::DT_UNDEFINED); + } + + return static_cast(iter->second); +} } // namespace ge diff --git a/src/ge/common/ge/datatype_util.h b/src/ge/common/ge/datatype_util.h index ee3fb74d..480b35e7 100644 --- a/src/ge/common/ge/datatype_util.h +++ b/src/ge/common/ge/datatype_util.h @@ -37,16 +37,17 @@ static const int32_t kGeSizeUint16 = sizeof(uint16_t); static const int32_t kGeSizeUint32 = sizeof(uint32_t); static std::map CONST_OPDATA_TYPE_SIZE_MAP = { - {ge::DT_FLOAT, kGeSizeFloat}, {ge::DT_FLOAT16, kGeSizeHalfFloat}, {ge::DT_INT8, kGeSizeInt8}, - {ge::DT_INT16, kGeSizeInt16}, {ge::DT_INT32, kGeSizeInt32}, {ge::DT_INT64, kGeSizeInt64}, - {ge::DT_UINT8, kGeSizeUint8}, {ge::DT_UINT16, kGeSizeUint16}, {ge::DT_UINT32, kGeSizeUint32}, - {ge::DT_UINT64, kGeSizeUint64}, {ge::DT_DOUBLE, kGeSizeDouble}, {ge::DT_BOOL, kGeSizeBool}}; + {ge::DT_FLOAT, kGeSizeFloat}, {ge::DT_FLOAT16, kGeSizeHalfFloat}, {ge::DT_INT8, kGeSizeInt8}, + {ge::DT_INT16, kGeSizeInt16}, {ge::DT_INT32, kGeSizeInt32}, {ge::DT_INT64, kGeSizeInt64}, + {ge::DT_UINT8, kGeSizeUint8}, {ge::DT_UINT16, kGeSizeUint16}, {ge::DT_UINT32, kGeSizeUint32}, + {ge::DT_UINT64, 
kGeSizeUint64}, {ge::DT_DOUBLE, kGeSizeDouble}, {ge::DT_BOOL, kGeSizeBool}}; class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY DataTypeUtil { public: static bool DataTypeTranslatable(const ge::DataType &src_out_data_type, const ge::DataType &dst_in_data_type); static const std::vector &GetTranslatableDataTypesBySrc(const ge::DataType &src_out_data_type); static const std::vector &GetTranslatableDataTypesByDst(const ge::DataType &dst_in_data_type); + static int32_t GetIrDataType(ge::DataType data_type); }; } // namespace ge #endif // GE_COMMON_GE_DATATYPE_UTIL_H_ diff --git a/src/ge/common/ge/tbe_plugin_manager.cc b/src/ge/common/ge/tbe_plugin_manager.cc index d651ced1..8a594cb9 100644 --- a/src/ge/common/ge/tbe_plugin_manager.cc +++ b/src/ge/common/ge/tbe_plugin_manager.cc @@ -187,8 +187,8 @@ void TBEPluginManager::LoadCustomOpLib() { std::vector registration_datas = domi::OpRegistry::Instance()->registrationDatas; GELOGI("The size of registration_datas is: %zu", registration_datas.size()); for (OpRegistrationData reg_data : registration_datas) { - GELOGD("Begin to register optype: %s, imply_type: %u", reg_data.GetOmOptype().c_str(), - static_cast(reg_data.GetImplyType())); + GELOGD("Begin to register optype: %s, imply_type: %s", reg_data.GetOmOptype().c_str(), + TypeUtils::ImplyTypeToSerialString(reg_data.GetImplyType()).c_str()); domi::OpRegistry::Instance()->Register(reg_data); } } diff --git a/src/ge/common/ge_common.mk b/src/ge/common/ge_common.mk index e99ff654..e913c8f5 100644 --- a/src/ge/common/ge_common.mk +++ b/src/ge/common/ge_common.mk @@ -36,7 +36,6 @@ GE_COMMON_LOCAL_SRC_FILES := \ properties_manager.cc \ types.cc\ model_parser/base.cc \ - model_parser/graph_parser_util.cc \ tbe_kernel_store.cc \ op/attr_value_util.cc \ op/ge_op_utils.cc \ diff --git a/src/ge/common/math/math_util.h b/src/ge/common/math/math_util.h index 08088eb1..86c62209 100644 --- a/src/ge/common/math/math_util.h +++ b/src/ge/common/math/math_util.h @@ -562,7 +562,6 @@ inline 
Status CheckUint64MulOverflow(uint64_t a, uint64_t b) { /// @return Status inline Status CheckFp16MulOverflow(fp16_t a, fp16_t b) { fp16_t result = static_cast(a) * static_cast(b); - printf("result: %u, 0x%x\n", result.val, result.val); if (FP16_IS_INVALID(result.val)) { return FAILED; } @@ -885,6 +884,23 @@ inline Status CheckInt32DivOverflow(int32_t a, int32_t b) { static_cast(b)); \ return INTERNAL_ERROR; \ } -} // namespace ge +#define FMK_FP16_ZEROCHECK(a) \ + if (fabs(a) < DBL_EPSILON) { \ + GELOGE(INTERNAL_ERROR, "fp16 %f can not be zero !", a); \ + return INTERNAL_ERROR; \ + } + +#define FMK_FLOAT_ZEROCHECK(a) \ + if (fabs(a) < FLT_EPSILON) { \ + GELOGE(INTERNAL_ERROR, "float %f can not be zero !", a); \ + return INTERNAL_ERROR; \ + } + +#define FMK_DOUBLE_ZEROCHECK(a) \ + if (fabs(a) < DBL_EPSILON) { \ + GELOGE(INTERNAL_ERROR, "double %lf can not be zero !", a); \ + return INTERNAL_ERROR; \ + } +} // namespace ge #endif // GE_COMMON_MATH_MATH_UTIL_H_ diff --git a/src/ge/common/model_parser/graph_parser_util.cc b/src/ge/common/model_parser/graph_parser_util.cc deleted file mode 100644 index 19f505c1..00000000 --- a/src/ge/common/model_parser/graph_parser_util.cc +++ /dev/null @@ -1,501 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "graph_parser_util.h" -#include -#include "common/auth/file_saver.h" -#include "common/convert/pb2json.h" -#include "common/debug/log.h" -#include "common/debug/memory_dumper.h" -#include "common/model_parser/base.h" -#include "common/model_saver.h" -#include "common/properties_manager.h" -#include "common/string_util.h" -#include "common/types.h" -#include "common/util.h" -#include "common/util/error_manager/error_manager.h" -#include "external/register/register_types.h" -#include "framework/common/debug/ge_log.h" -#include "framework/omg/parser/parser_inner_ctx.h" -#include "graph/compute_graph.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/optimize/common/params.h" -#include "graph/utils/type_utils.h" -#include "omg/omg_inner_types.h" -#include "omg/parser/model_parser.h" -#include "omg/parser/parser_factory.h" -#include "omg/parser/weights_parser.h" -#include "parser/common/pre_checker.h" -#include "proto/ge_ir.pb.h" -#include "register/op_registry.h" - -namespace ge { -namespace { -// The function is incomplete. Currently, only l2_optimize, off_optimize is supported. 
-const char *const kInputShapeSample1 = "\"input_name1:n1,c1,h1,w1\""; -const char *const kInputShapeSample2 = "\"input_name1:1,3,224,224\""; -const char *const kSplitError1 = "size not equal to 2 split by \":\""; -const char *const kEmptyError = "can not be empty"; -const char *const kFloatNumError = "exist float number"; -const char *const kDigitError = "is not digit"; -const char *const kOutputTypeSample = "correct sample is \"opname:index:dtype\""; -const char *const kOutputTypeSupport = "only support FP32, FP16, UINT8"; -const char *const kOutputTypeError = "The multiple out nodes set in output_type must be found in out_nodes."; - -vector SplitInputShape(const std::string &input_shape) { - vector shape_pair_vec; - size_t pos = input_shape.rfind(":"); - if (pos != std::string::npos) { - shape_pair_vec.emplace_back(input_shape.substr(0, pos)); - shape_pair_vec.emplace_back(input_shape.substr(pos + 1, input_shape.size() - pos)); - } - return shape_pair_vec; -} - -static std::map output_type_str_to_datatype = { - {"FP32", ge::DT_FLOAT}, {"FP16", ge::DT_FLOAT16}, {"UINT8", ge::DT_UINT8}}; - -static bool CheckInputTrueOrFalse(const std::string &s, const std::string &atc_param) { - if ((s == "true") || (s == "false")) { - return true; - } else { - ErrorManager::GetInstance().ATCReportErrMessage("E10033", {"parameter", "value"}, {atc_param, s}); - GELOGE(PARAM_INVALID, "Input parameter[--%s]'s value[%s] must be true or false.", atc_param.c_str(), s.c_str()); - return false; - } -} - -bool CheckDigitStr(std::string &str) { - for (char c : str) { - if (!isdigit(c)) { - GELOGE(domi::FAILED, "value[%s] is not positive integer", str.c_str()); - return false; - } - } - return true; -} - -Status StringToInt(std::string &str, int32_t &value) { - try { - if (!CheckDigitStr(str)) { - GELOGE(PARAM_INVALID, "Invalid of digit string: %s ", str.c_str()); - ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, - {"--output_type", str, "is not 
positive integer"}); - return PARAM_INVALID; - } - value = stoi(str); - } catch (std::invalid_argument &) { - GELOGE(PARAM_INVALID, "Invalid of digit string: %s, catch invalid_argument.", str.c_str()); - ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, {"output_type", str}); - return PARAM_INVALID; - } catch (std::out_of_range &) { - GELOGE(PARAM_INVALID, "Invalid of digit string: %s, catch out_of_range.", str.c_str()); - ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, {"output_type", str}); - return PARAM_INVALID; - } - return SUCCESS; -} - -Status VerifyOutputTypeAndOutNodes(std::vector &out_type_vec) { - std::vector> user_out_nodes = domi::GetContext().user_out_nodes; - std::set out_nodes_info; - for (uint32_t i = 0; i < user_out_nodes.size(); ++i) { - // out_nodes set should include output_type and output_format - std::string tmp = user_out_nodes[i].first + ":" + to_string(user_out_nodes[i].second); - out_nodes_info.emplace(tmp); - } - for (uint32_t i = 0; i < out_type_vec.size(); ++i) { - if (out_nodes_info.find(out_type_vec[i]) == out_nodes_info.end()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, - {"--output_type", out_type_vec[i], kOutputTypeError}); - GELOGE(domi::FAILED, "Invalid value for --output_type[%s], %s.", out_type_vec[i].c_str(), kOutputTypeError); - return domi::FAILED; - } - } - return domi::SUCCESS; -} - -Status ParseOutputType(const std::string &output_type, std::map> &out_type_index_map, - std::map> &out_type_dt_map) { - if (output_type.find(':') == std::string::npos) { - GELOGI("output_type is not multiple nodes, means all out nodes"); - auto it = output_type_str_to_datatype.find(output_type); - if (it == output_type_str_to_datatype.end()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, - {"--output_type", output_type, kOutputTypeSupport}); - GELOGE(PARAM_INVALID, "Invalid 
value for --output_type[%s], %s.", output_type.c_str(), kOutputTypeSupport); - return domi::FAILED; - } - return domi::SUCCESS; - } - std::vector out_type_vec; - vector nodes_v = StringUtils::Split(output_type, ';'); - for (const string &node : nodes_v) { - vector node_index_type_v = StringUtils::Split(node, ':'); - if (node_index_type_v.size() != 3) { // The size must be 3. - ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, - {"--output_type", node, kOutputTypeSample}); - GELOGE(PARAM_INVALID, "Invalid value for --output_type[%s], %s.", node.c_str(), kOutputTypeSample); - return domi::FAILED; - } - ge::DataType tmp_dt; - std::string node_name = StringUtils::Trim(node_index_type_v[0]); - std::string index_str = StringUtils::Trim(node_index_type_v[1]); - int32_t index; - if (StringToInt(index_str, index) != SUCCESS) { - GELOGE(PARAM_INVALID, "This str must be digit string, while the actual input is %s.", index_str.c_str()); - return domi::FAILED; - } - std::string dt_value = StringUtils::Trim(node_index_type_v[2]); - auto it = output_type_str_to_datatype.find(dt_value); - if (it == output_type_str_to_datatype.end()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, - {"--output_type", dt_value, kOutputTypeSupport}); - GELOGE(ge::PARAM_INVALID, "Invalid value for --output_type[%s], %s.", dt_value.c_str(), kOutputTypeSupport); - return domi::FAILED; - } else { - tmp_dt = it->second; - } - out_type_vec.push_back(node_name + ":" + index_str); - auto it_index = out_type_index_map.find(node_name); - if (it_index == out_type_index_map.end()) { - vector tmp_vec; - tmp_vec.push_back(index); - out_type_index_map.emplace(node_name, tmp_vec); - } else { - it_index->second.push_back(index); - } - - auto it_dt = out_type_dt_map.find(node_name); - if (it_dt == out_type_dt_map.end()) { - vector tmp_vec; - tmp_vec.push_back(tmp_dt); - out_type_dt_map.emplace(node_name, tmp_vec); - } else { - 
it_dt->second.push_back(tmp_dt); - } - } - return VerifyOutputTypeAndOutNodes(out_type_vec); -} - -Status CheckOutNode(ge::OpDescPtr op_desc, int32_t index) { - int32_t out_size = op_desc->GetOutputsSize(); - if (index < 0 || index >= out_size) { - GELOGE(domi::FAILED, - "out_node [%s] output index:%d must be smaller " - "than node output size:%d and can not be negative!", - op_desc->GetName().c_str(), index, out_size); - std::string fail_reason = "output index:" + to_string(index) + - " must be smaller than output size:" + to_string(out_size) + " and can not be negative!"; - ErrorManager::GetInstance().ATCReportErrMessage("E10003", {"parameter", "value", "reason"}, - {"out_nodes", op_desc->GetName(), fail_reason}); - return domi::FAILED; - } - return domi::SUCCESS; -} - -Status GetOutputLeaf(NodePtr node, std::vector> &output_nodes_info) { - ge::OpDescPtr tmpDescPtr = node->GetOpDesc(); - if (tmpDescPtr == nullptr) { - GELOGE(domi::FAILED, "Get outnode op desc fail."); - return domi::FAILED; - } - size_t size = tmpDescPtr->GetOutputsSize(); - if (node->GetType() != NETOUTPUT) { - for (size_t index = 0; index < size; ++index) { - output_nodes_info.push_back(std::make_pair(node, index)); - } - } else { - const auto in_anchors = node->GetAllInDataAnchors(); - for (auto in_anchor : in_anchors) { - auto out_anchor = in_anchor->GetPeerOutAnchor(); - if (out_anchor == nullptr) { - GELOGE(domi::FAILED, "Get leaf node op desc fail."); - return domi::FAILED; - } - auto out_node = out_anchor->GetOwnerNode(); - output_nodes_info.push_back(std::make_pair(out_node, out_anchor->GetIdx())); - } - } - return SUCCESS; -} - -void GetOutputNodesNameAndIndex(std::vector> &output_nodes_info, - std::vector &output_nodes_name) { - output_nodes_name.clear(); - if (domi::GetContext().out_top_names.empty()) { - // tf process, no top name. 
- for (const auto output_node_info : output_nodes_info) { - std::string node_name = output_node_info.first->GetName(); - int32_t index = output_node_info.second; - output_nodes_name.push_back(node_name + ":" + std::to_string(index)); - } - return; - } - // caffe process, need add top name after node_name:index - for (size_t i = 0; i < output_nodes_info.size(); ++i) { - std::string node_name = output_nodes_info[i].first->GetName(); - int32_t index = output_nodes_info[i].second; - if (i < domi::GetContext().out_top_names.size()) { - output_nodes_name.push_back(node_name + ":" + std::to_string(index) + ":" + domi::GetContext().out_top_names[i]); - } else { - GELOGW("Get top name of node [%s] fail.", node_name.c_str()); - output_nodes_name.push_back(node_name + ":" + std::to_string(index)); - } - } -} -} // namespace - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ParseOutputFp16NodesFormat(const string &is_output_fp16) { - if (is_output_fp16.empty()) { - return SUCCESS; - } - - vector &output_formats = domi::GetContext().output_formats; - output_formats.clear(); - vector node_format_vec = StringUtils::Split(is_output_fp16, ','); - for (auto &is_fp16 : node_format_vec) { - StringUtils::Trim(is_fp16); - if (!CheckInputTrueOrFalse(is_fp16, "is_output_adjust_hw_layout")) { - GELOGE(PARAM_INVALID, "Invalid Param, is_output_adjust_hw_layout only support true/false: but is [%s]", - is_output_fp16.c_str()); - return PARAM_INVALID; - } - if (is_fp16 == "false") { - output_formats.push_back(DOMI_TENSOR_ND); - } else if (is_fp16 == "true") { - output_formats.push_back(domi::DOMI_TENSOR_NC1HWC0); - } - } - return SUCCESS; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status SetOutputNodeInfo(ge::Graph &graph, - const std::string &output_type, - const std::string &output) { - ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); - GE_CHECK_NOTNULL(compute_graph); - - std::vector> user_out_nodes = domi::GetContext().user_out_nodes; - 
std::vector output_formats = domi::GetContext().output_formats; - std::vector> output_nodes_info; - std::vector output_nodes_name; - std::map> out_type_index_map; - std::map> out_type_dt_map; - if (!output_type.empty()) { - if (ParseOutputType(output_type, out_type_index_map, out_type_dt_map) != SUCCESS) { - GELOGE(domi::FAILED, "Parse output_type failed."); - return domi::FAILED; - } - } - - // User declared outputs - for (uint32_t i = 0; i < user_out_nodes.size(); ++i) { - ge::NodePtr out_node = compute_graph->FindNode(user_out_nodes[i].first); - if (out_node == nullptr) { - GELOGE(domi::FAILED, "Can not find src node (%s) in graph.", user_out_nodes[i].first.c_str()); - return domi::FAILED; - } - auto op_desc = out_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - if (CheckOutNode(op_desc, user_out_nodes[i].second) != SUCCESS) { - GELOGE(domi::FAILED, "Check out node (%s) fail.", user_out_nodes[i].first.c_str()); - return domi::FAILED; - } - if (i < output_formats.size()) { - if (output_formats[i] == domi::DOMI_TENSOR_NC1HWC0) { - GELOGI("The output node [%s] should be set NC1HWC0", user_out_nodes[i].first.c_str()); - if (!ge::AttrUtils::SetBool(op_desc, "output_set_fp16_nc1hwc0", true)) { - GELOGW("The output node [%s] set NC1HWC0 failed", user_out_nodes[i].first.c_str()); - } - } - } - auto it_index = out_type_index_map.find(user_out_nodes[i].first); - auto it_dt = out_type_dt_map.find(user_out_nodes[i].first); - if ((it_index != out_type_index_map.end()) && (it_dt != out_type_dt_map.end())) { - GELOGI("The output node [%s] need to be set output_type", user_out_nodes[i].first.c_str()); - (void)ge::AttrUtils::SetListDataType(op_desc, "_output_dt_list", it_dt->second); - (void)ge::AttrUtils::SetListInt(op_desc, "_output_dt_index", it_index->second); - } - output_nodes_info.push_back(std::make_pair(out_node, user_out_nodes[i].second)); - } - // default output node (leaf) - if (user_out_nodes.empty()) { - for (ge::NodePtr node : compute_graph->GetDirectNode()) { - 
if (!node->GetInDataNodes().empty() && node->GetOutDataNodes().empty()) { - Status ret = GetOutputLeaf(node, output_nodes_info); - GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "find leaf fail."); - } - } - } - GetOutputNodesNameAndIndex(output_nodes_info, output_nodes_name); - compute_graph->SetGraphOutNodesInfo(output_nodes_info); - domi::GetContext().net_out_nodes = output_nodes_name; - return domi::SUCCESS; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ParseInputShape( - const string &input_shape, unordered_map> &shape_map, - vector>> &user_shape_map, bool is_dynamic_input) { - vector shape_vec = StringUtils::Split(input_shape, ';'); - const int DEFAULT_SHAPE_PAIR_SIZE = 2; - for (const auto &shape : shape_vec) { - vector shape_pair_vec = SplitInputShape(shape); - if (shape_pair_vec.size() != DEFAULT_SHAPE_PAIR_SIZE) { - ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"}, - {shape, kSplitError1, kInputShapeSample1}); - GELOGW("Parse input parameter [--input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.", - shape.c_str(), kSplitError1, kInputShapeSample1); - return false; - } - if (shape_pair_vec[1].empty()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"}, - {shape, kEmptyError, kInputShapeSample1}); - GELOGW("Parse input parameter [--input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.", - shape.c_str(), kEmptyError, kInputShapeSample1); - return false; - } - - vector shape_value_strs = StringUtils::Split(shape_pair_vec[1], ','); - vector shape_values; - for (auto &shape_value_str : shape_value_strs) { - // stoul: The method may throw an exception: invalid_argument/out_of_range - if (std::string::npos != shape_value_str.find('.')) { - ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"}, - {shape, kFloatNumError, kInputShapeSample2}); - GELOGW("Parse input parameter [--input_shape]'s shape[%s] failed, 
reason: %s, correct sample is %s.", - shape.c_str(), kFloatNumError, kInputShapeSample2); - return false; - } - - long left_result = 0; - try { - left_result = stol(StringUtils::Trim(shape_value_str)); - if (!shape_value_str.empty() && (shape_value_str.front() == '-')) { - // The value maybe dynamic shape [-1], need substr it and verify isdigit. - shape_value_str = shape_value_str.substr(1); - } - for (char c : shape_value_str) { - if (!isdigit(c)) { - ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"}, - {shape, kDigitError, kInputShapeSample2}); - GELOGE(PARAM_INVALID, "--input_shape's shape value[%s] is not digit", shape_value_str.c_str()); - return false; - } - } - } catch (const std::out_of_range &) { - ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, - {"input_shape", shape_value_str}); - GELOGW("Input parameter[--input_shape]’s value[%s] cause out of range execption!", shape_value_str.c_str()); - return false; - } catch (const std::invalid_argument &) { - ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, - {"input_shape", shape_value_str}); - GELOGW("Input parameter[--input_shape]’s value[%s] cause invalid argument!", shape_value_str.c_str()); - return false; - } catch (...) 
{ - ErrorManager::GetInstance().ATCReportErrMessage("E10015", {"parameter", "value"}, - {"input_shape", shape_value_str}); - GELOGW("Input parameter[--input_shape]’s value[%s] cause unkown execption!", shape_value_str.c_str()); - return false; - } - int64_t result = left_result; - // - 1 is not currently supported - if (!is_dynamic_input && result <= 0) { - ErrorManager::GetInstance().ATCReportErrMessage("E10011", {"shape", "result"}, {shape, std::to_string(result)}); - GELOGW( - "Input parameter[--input_shape]’s shape value[%s] is invalid, " - "expect positive integer, but value is %ld.", - shape.c_str(), result); - return false; - } - shape_values.push_back(result); - } - - shape_map.emplace(make_pair(StringUtils::Trim(shape_pair_vec[0]), shape_values)); - user_shape_map.push_back(make_pair(StringUtils::Trim(shape_pair_vec[0]), shape_values)); - } - - return true; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ParseOutputNodes(const string &out_nodes) { - try { - // parse output node - if (!out_nodes.empty()) { - domi::GetContext().out_nodes_map.clear(); - domi::GetContext().user_out_nodes.clear(); - - vector nodes_v = StringUtils::Split(out_nodes, ';'); - for (const string &node : nodes_v) { - vector key_value_v = StringUtils::Split(node, ':'); - if (key_value_v.size() != 2) { // The size must be 2. 
- ErrorManager::GetInstance().ATCReportErrMessage( - "E10001", {"parameter", "value", "reason"}, - {"--out_nodes", node, "the correct format is \"node_name1:0;node_name1:1;node_name2:0\""}); - GELOGE(PARAM_INVALID, - "The input format of --out_nodes is invalid, the correct format is " - "\"node_name1:0;node_name1:1;node_name2:0\", while the actual input is %s.", - node.c_str()); - return PARAM_INVALID; - } - auto iter = domi::GetContext().out_nodes_map.find(key_value_v[0]); - // stoi: The method may throw an exception: invalid_argument/out_of_range - if (!CheckDigitStr(key_value_v[1])) { - ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, - {"--out_nodes", out_nodes, "is not positive integer"}); - GELOGE(PARAM_INVALID, "This str must be digit string, while the actual input is %s", out_nodes.c_str()); - return PARAM_INVALID; - } - int32_t index = stoi(StringUtils::Trim(key_value_v[1])); - if (iter != domi::GetContext().out_nodes_map.end()) { - iter->second.emplace_back(index); - } else { - std::vector index_v; - index_v.emplace_back(index); - domi::GetContext().out_nodes_map.emplace(key_value_v[0], index_v); - } - domi::GetContext().user_out_nodes.push_back(std::make_pair(key_value_v[0], index)); - } - } - } catch (std::invalid_argument &) { - GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", out_nodes.c_str()); - ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, {"out_nodes", out_nodes}); - return PARAM_INVALID; - } catch (std::out_of_range &) { - GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", out_nodes.c_str()); - ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, {"out_nodes", out_nodes}); - return PARAM_INVALID; - } - return SUCCESS; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ParseOpConf(const char *op_conf) { - if (op_conf != nullptr && *op_conf != '\0') { - // divided by ":" - 
PropertiesManager::Instance().SetPropertyDelimiter(OP_CONF_DELIMITER); - // Parsing the op_conf configuration item file - if (!PropertiesManager::Instance().Init(op_conf)) { - GELOGE(FAILED, "op_name_map init failed!"); - return FAILED; - } - // Return map and put it into ATC global variable - domi::GetContext().op_conf_map = PropertiesManager::Instance().GetPropertyMap(); - } - return SUCCESS; -} -} // namespace ge diff --git a/src/ge/common/model_parser/graph_parser_util.h b/src/ge/common/model_parser/graph_parser_util.h deleted file mode 100644 index b38642c2..00000000 --- a/src/ge/common/model_parser/graph_parser_util.h +++ /dev/null @@ -1,62 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef GE_COMMON_GRAPH_PARSER_UTIL_H_ -#define GE_COMMON_GRAPH_PARSER_UTIL_H_ - -#include -#include -#include -#include -#include "framework/common/types.h" -#include "framework/omg/omg_inner_types.h" -#include "proto/ge_ir.pb.h" -#include "proto/om.pb.h" - -#include "graph/compute_graph.h" -#include "graph/graph.h" -#include "graph/model.h" -#include "runtime/kernel.h" - -using domi::Status; -using std::pair; -using std::string; -using std::unordered_map; -using std::vector; - -namespace ge { -Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const std::string &output_format); - -Status ParseOutputFp16NodesFormat(const string &is_output_fp16); - -Status ParseOutputNodes(const string &out_nodes); - -bool ParseInputShape(const string &input_shape, unordered_map> &shape_map, - vector>> &user_shape_map, bool is_dynamic_input); - -Status ParseOpConf(const char *op_conf); -} // namespace ge - -namespace domi { -/** - * @ingroup domi_omg - * @brief get omg context - * @return reference of OmgContext - */ -ge::OmgContext &GetContext(); -} // namespace domi - -#endif // GE_COMMON_GRAPH_PARSER_UTIL_H_ diff --git a/src/ge/common/profiling/profiling_manager.cc b/src/ge/common/profiling/profiling_manager.cc index 364f8298..503d52a1 100644 --- a/src/ge/common/profiling/profiling_manager.cc +++ b/src/ge/common/profiling/profiling_manager.cc @@ -76,8 +76,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::In for (size_t i = 0; i < device_id_.size(); ++i) { ret = StartProfiling(0, device_id_[i]); if (ret != SUCCESS) { - GELOGE(ret, "Profiling start failed on device %d.", device_id_[i]); - return FAILED; + GELOGW("Profiling start failed on device %d.", device_id_[i]); + continue; } GELOGI("Profiling init succ on device %d.", device_id_[i]); } @@ -316,7 +316,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::St ProfMgrCfg prof_cfg = {send_profiling_config_}; void *prof_handle = 
ProfMgrStartUp(&prof_cfg); if (prof_handle == nullptr) { - GELOGW("ProfMgrStartUp failed."); + GELOGW("ProfMgrStartUp failed on device %d ", device_id); return FAILED; } GELOGD("StartProfiling, prof_handle: %p", prof_handle); diff --git a/src/ge/common/properties_manager.cc b/src/ge/common/properties_manager.cc index 0c2b1db6..2e2405e7 100644 --- a/src/ge/common/properties_manager.cc +++ b/src/ge/common/properties_manager.cc @@ -31,193 +31,6 @@ #include "graph/utils/attr_utils.h" namespace ge { -namespace { -const string kEnableFlag = "1"; - -const uint32_t kAicoreOverflow = (0x1 << 0); -const uint32_t kAtomicOverflow = (0x1 << 1); -const uint32_t kAllOverflow = (kAicoreOverflow | kAtomicOverflow); -} // namespace - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY DumpProperties::DumpProperties(const DumpProperties &other) { - CopyFrom(other); -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY DumpProperties &DumpProperties::operator=( - const DumpProperties &other) { - CopyFrom(other); - return *this; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::InitByOptions() { - enable_dump_.clear(); - enable_dump_debug_.clear(); - dump_path_.clear(); - dump_step_.clear(); - dump_mode_.clear(); - is_op_debug_ = false; - op_debug_mode_ = 0; - - string enable_dump; - (void)GetContext().GetOption(OPTION_EXEC_ENABLE_DUMP, enable_dump); - enable_dump_ = enable_dump; - - string enable_dump_debug; - (void)GetContext().GetOption(OPTION_EXEC_ENABLE_DUMP_DEBUG, enable_dump_debug); - enable_dump_debug_ = enable_dump_debug; - - if ((enable_dump_ == kEnableFlag) || (enable_dump_debug_ == kEnableFlag)) { - string dump_path; - if (GetContext().GetOption(OPTION_EXEC_DUMP_PATH, dump_path) == GRAPH_SUCCESS) { - if (!dump_path.empty() && dump_path[dump_path.size() - 1] != '/') { - dump_path = dump_path + "/"; - } - dump_path = dump_path + CurrentTimeInStr() + "/"; - GELOGI("Get dump path %s successfully", dump_path.c_str()); - SetDumpPath(dump_path); - } else { 
- GELOGW("DUMP_PATH is not set"); - } - } - - if (enable_dump_ == kEnableFlag) { - string dump_step; - if (GetContext().GetOption(OPTION_EXEC_DUMP_STEP, dump_step) == GRAPH_SUCCESS) { - GELOGD("Get dump step %s successfully", dump_step.c_str()); - SetDumpStep(dump_step); - } - string dump_mode; - if (GetContext().GetOption(OPTION_EXEC_DUMP_MODE, dump_mode) == GRAPH_SUCCESS) { - GELOGD("Get dump mode %s successfully", dump_mode.c_str()); - SetDumpMode(dump_mode); - } - AddPropertyValue(DUMP_ALL_MODEL, {}); - } - - SetDumpDebugOptions(); -} - -// The following is the new dump scenario of the fusion operator -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::AddPropertyValue( - const std::string &model, const std::set &layers) { - for (const std::string &layer : layers) { - GELOGI("This model %s config to dump layer %s", model.c_str(), layer.c_str()); - } - - model_dump_properties_map_[model] = layers; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::DeletePropertyValue(const std::string &model) { - auto iter = model_dump_properties_map_.find(model); - if (iter != model_dump_properties_map_.end()) { - model_dump_properties_map_.erase(iter); - } -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::set DumpProperties::GetAllDumpModel() const { - std::set model_list; - for (auto &iter : model_dump_properties_map_) { - model_list.insert(iter.first); - } - - return model_list; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::set DumpProperties::GetPropertyValue( - const std::string &model) const { - auto iter = model_dump_properties_map_.find(model); - if (iter != model_dump_properties_map_.end()) { - return iter->second; - } - return {}; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool DumpProperties::IsLayerNeedDump( - const std::string &model, const std::string &om_name, const std::string &op_name) const { - // if dump all - if (model_dump_properties_map_.find(DUMP_ALL_MODEL) != 
model_dump_properties_map_.end()) { - return true; - } - - // if this model need dump - auto om_name_iter = model_dump_properties_map_.find(om_name); - auto model_name_iter = model_dump_properties_map_.find(model); - if (om_name_iter != model_dump_properties_map_.end() || model_name_iter != model_dump_properties_map_.end()) { - // if no dump layer info, dump all layer in this model - auto model_iter = om_name_iter != model_dump_properties_map_.end() ? om_name_iter : model_name_iter; - if (model_iter->second.empty()) { - return true; - } - - return model_iter->second.find(op_name) != model_iter->second.end(); - } - - GELOGD("Model %s is not seated to be dump.", model.c_str()); - return false; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpPath(const std::string &path) { - dump_path_ = path; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string DumpProperties::GetDumpPath() const { return dump_path_; } - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpStep(const std::string &step) { - dump_step_ = step; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string DumpProperties::GetDumpStep() const { return dump_step_; } - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpMode(const std::string &mode) { - dump_mode_ = mode; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string DumpProperties::GetDumpMode() const { return dump_mode_; } - -void DumpProperties::CopyFrom(const DumpProperties &other) { - if (&other != this) { - enable_dump_ = other.enable_dump_; - enable_dump_debug_ = other.enable_dump_debug_; - dump_path_ = other.dump_path_; - dump_step_ = other.dump_step_; - dump_mode_ = other.dump_mode_; - - model_dump_properties_map_ = other.model_dump_properties_map_; - is_op_debug_ = other.is_op_debug_; - op_debug_mode_ = other.op_debug_mode_; - } -} - -void DumpProperties::SetDumpDebugOptions() { - if (enable_dump_debug_ == kEnableFlag) { - 
string dump_debug_mode; - if (GetContext().GetOption(OPTION_EXEC_DUMP_DEBUG_MODE, dump_debug_mode) == GRAPH_SUCCESS) { - GELOGD("Get dump debug mode %s successfully", dump_debug_mode.c_str()); - } else { - GELOGW("Dump debug mode is not set."); - return; - } - - if (dump_debug_mode == OP_DEBUG_AICORE) { - GELOGD("ge.exec.dumpDebugMode=aicore_overflow, op debug is open."); - is_op_debug_ = true; - op_debug_mode_ = kAicoreOverflow; - } else if (dump_debug_mode == OP_DEBUG_ATOMIC) { - GELOGD("ge.exec.dumpDebugMode=atomic_overflow, op debug is open."); - is_op_debug_ = true; - op_debug_mode_ = kAtomicOverflow; - } else if (dump_debug_mode == OP_DEBUG_ALL) { - GELOGD("ge.exec.dumpDebugMode=all, op debug is open."); - is_op_debug_ = true; - op_debug_mode_ = kAllOverflow; - } else { - GELOGW("ge.exec.dumpDebugMode is invalid."); - } - } else { - GELOGI("ge.exec.enableDumpDebug is false or is not set."); - } -} - PropertiesManager::PropertiesManager() : is_inited_(false), delimiter("=") {} PropertiesManager::~PropertiesManager() {} diff --git a/src/ge/common/properties_manager.h b/src/ge/common/properties_manager.h index 3b1547f5..e4e84f74 100644 --- a/src/ge/common/properties_manager.h +++ b/src/ge/common/properties_manager.h @@ -24,6 +24,7 @@ #include #include "graph/op_desc.h" +#include "common/dump/dump_properties.h" namespace ge { // Configuration property management @@ -32,50 +33,6 @@ static const char *USE_FUSION __attribute__((unused)) = "FMK_USE_FUSION"; static const char *TIMESTAT_ENABLE __attribute__((unused)) = "DAVINCI_TIMESTAT_ENABLE"; static const char *ANNDROID_DEBUG __attribute__((unused)) = "ANNDROID_DEBUG"; -class DumpProperties { - public: - DumpProperties() = default; - ~DumpProperties() = default; - DumpProperties(const DumpProperties &dump); - DumpProperties &operator=(const DumpProperties &dump); - - void InitByOptions(); - - void AddPropertyValue(const std::string &model, const std::set &layers); - void DeletePropertyValue(const std::string 
&model); - - std::set GetAllDumpModel() const; - std::set GetPropertyValue(const std::string &model) const; - bool IsLayerNeedDump(const std::string &model, const std::string &om_name, const std::string &op_name) const; - - void SetDumpPath(const std::string &path); - std::string GetDumpPath() const; - - void SetDumpStep(const std::string &step); - std::string GetDumpStep() const; - - void SetDumpMode(const std::string &mode); - std::string GetDumpMode() const; - - bool IsOpDebugOpen() const { return is_op_debug_; } - uint32_t GetOpDebugMode() const { return op_debug_mode_; } - - private: - void CopyFrom(const DumpProperties &other); - void SetDumpDebugOptions(); - - string enable_dump_; - string enable_dump_debug_; - - std::string dump_path_; - std::string dump_step_; - std::string dump_mode_; - std::map> model_dump_properties_map_; - - bool is_op_debug_ = false; - uint32_t op_debug_mode_ = 0; -}; - class PropertiesManager { public: // Singleton diff --git a/src/ge/common/types.cc b/src/ge/common/types.cc index 2de75ff6..de293d34 100644 --- a/src/ge/common/types.cc +++ b/src/ge/common/types.cc @@ -502,6 +502,7 @@ const uint32_t MODEL_FILE_HEAD_LEN = 256; /// @brief Input node type /// const std::string INPUT_TYPE = "Input"; +const std::string DUMMY_DATA = "DummyData"; /// /// @ingroup domi_omg diff --git a/src/ge/common/util.cc b/src/ge/common/util.cc index a52978af..dca50627 100644 --- a/src/ge/common/util.cc +++ b/src/ge/common/util.cc @@ -57,7 +57,7 @@ const int kWarningThreshold = 536870912 * 2; // 536870912 represent 512M /// Based on the security coding specification and the current actual (protobuf) model size, it is determined as 2G-1 const int kMaxFileSizeLimit = INT_MAX; const int kMaxBuffSize = 256; -const char *const kPathValidReason = "The path can only contains 'a-z' 'A-Z' '0-9' '-' '.' '_' and chinese character"; +const char *const kPathValidReason = "The path can only contain 'a-z' 'A-Z' '0-9' '-' '.' 
'_' and chinese character"; } // namespace namespace ge { @@ -311,6 +311,14 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t GetCurrentTimestap() { return static_cast(total_use_time); } +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint32_t GetCurrentSecondTimestap() { + struct timeval tv {}; + int ret = gettimeofday(&tv, nullptr); + GE_LOGE_IF(ret != 0, "Func gettimeofday may failed: ret=%d", ret); + auto total_use_time = tv.tv_sec; // seconds + return static_cast(total_use_time); +} + FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInt64MulOverflow(int64_t a, int64_t b) { if (a > 0) { if (b > 0) { @@ -372,10 +380,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const } // A regular matching expression to verify the validity of the input file path - // ^(/|./|(../)+|)([.]?[\u4e00-\u9fa5A-Za-z0-9_.-]+/)*[\u4e00-\u9fa5A-Za-z0-9_+.-]+$ // Path section: Support upper and lower case letters, numbers dots(.) chinese and underscores // File name section: Support upper and lower case letters, numbers, underscores chinese and dots(.) - std::string mode = "^(/+|./+|(../+)+|)(../|([.]?[\u4e00-\u9fa5A-Za-z0-9_.-]+)/+)*[\u4e00-\u9fa5A-Za-z0-9_+.-]+$"; + std::string mode = "^[\u4e00-\u9fa5A-Za-z0-9./_-]+$"; GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( !ValidateStr(real_path, mode), @@ -408,10 +415,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const return "", "Path[%s] len is too long, it must be less than %d", file_path.c_str(), PATH_MAX); // A regular matching expression to verify the validity of the input file path - // ^(/|./|(../)+|)([.]?[\u4e00-\u9fa5A-Za-z0-9_-]+/)*[\u4e00-\u9fa5A-Za-z0-9_+.-]+$ // Path section: Support upper and lower case letters, numbers dots(.) chinese and underscores // File name section: Support upper and lower case letters, numbers, underscores chinese and dots(.) 
- std::string mode = "^(/+|./+|(../+)+|)(../|([.]?[\u4e00-\u9fa5A-Za-z0-9_.-]+)/+)*[\u4e00-\u9fa5A-Za-z0-9_+.-]+$"; + std::string mode = "^[\u4e00-\u9fa5A-Za-z0-9./_-]+$"; GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( !ValidateStr(file_path, mode), @@ -460,9 +466,9 @@ FMK_FUNC_HOST_VISIBILITY bool ValidateStr(const std::string &str, const std::str int ret = regcomp(®, mode.c_str(), cflags); if (ret) { regerror(ret, ®, ebuff, kMaxBuffSize); - GELOGE(ge::PARAM_INVALID, "regcomp failed, reason: %s", ebuff); + GELOGW("regcomp failed, reason: %s", ebuff); regfree(®); - return false; + return true; } ret = regexec(®, str.c_str(), 0, nullptr, 0); diff --git a/src/ge/engine_manager/dnnengine_manager.cc b/src/ge/engine_manager/dnnengine_manager.cc index ad36ebb5..fe3c1bc8 100644 --- a/src/ge/engine_manager/dnnengine_manager.cc +++ b/src/ge/engine_manager/dnnengine_manager.cc @@ -42,6 +42,8 @@ const char *const kVectorCore = "VectorCore"; const char *const kVectorEngine = "VectorEngine"; const char *const kAIcoreEngine = "AIcoreEngine"; const char *const kCustomOpFlag = "_custom_op_flag"; +const char *const kHostCpuEngineName = "DNN_VM_HOST_CPU"; +const char *const kHostCpuOpKernelLibName = "DNN_VM_HOST_CPU_OP_STORE"; } // namespace namespace ge { @@ -181,6 +183,7 @@ std::string DNNEngineManager::GetDNNEngineName(const OpDescPtr &op_desc) { GELOGI("DNNEngineManager: Can not get op info by op type %s", op_desc->GetType().c_str()); return ""; } + GE_IF_BOOL_EXEC(ge::GetContext().GetHostExecFlag(), return GetHostCpuEngineName(op_infos, op_desc)); std::string ge_core_type; Status ret = ge::GetContext().GetOption(ge::CORE_TYPE, ge_core_type); GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGD("get the option CORE_TYPE fail, set it to default value VECTOR_ENGINE")); @@ -245,6 +248,22 @@ std::string DNNEngineManager::GetDNNEngineName(const OpDescPtr &op_desc) { return ""; } +std::string DNNEngineManager::GetHostCpuEngineName(const std::vector &op_infos, + const OpDescPtr &op_desc) const { + for (const auto 
&it : op_infos) { + if ((it.engine == kHostCpuEngineName) && (it.opKernelLib == kHostCpuOpKernelLibName)) { + op_desc->SetOpEngineName(kHostCpuEngineName); + op_desc->SetOpKernelLibName(kHostCpuOpKernelLibName); + GELOGI("DNNEngineManager: Set OpKernelLibName %s and OpEngineName %s to %s", kHostCpuOpKernelLibName, + kHostCpuEngineName, op_desc->GetName().c_str()); + return kHostCpuEngineName; + } + } + GELOGE(FAILED, "DNNEngineManager: HostCpuEngine not support [%s, %s].", op_desc->GetName().c_str(), + op_desc->GetType().c_str()); + return ""; +} + const std::map &DNNEngineManager::GetSchedulers() const { return schedulers_; } Status DNNEngineManager::ParserJsonFile() { diff --git a/src/ge/engine_manager/dnnengine_manager.h b/src/ge/engine_manager/dnnengine_manager.h index 15628ecf..6d5b02f9 100644 --- a/src/ge/engine_manager/dnnengine_manager.h +++ b/src/ge/engine_manager/dnnengine_manager.h @@ -76,6 +76,7 @@ class DNNEngineManager { Status ParserEngineMessage(const json engines_json, const string &scheduler_mark, map &engines); Status CheckJsonFile(); + std::string GetHostCpuEngineName(const std::vector &op_infos, const OpDescPtr &op_desc) const; PluginManager plugin_mgr_; std::map engines_map_; std::map engines_attrs_map_; diff --git a/src/ge/engine_manager/engine_conf.json b/src/ge/engine_manager/engine_conf.json index 8c8990ee..82360562 100755 --- a/src/ge/engine_manager/engine_conf.json +++ b/src/ge/engine_manager/engine_conf.json @@ -5,6 +5,13 @@ "name": "1980_hwts", "ex_attrs": "", "cal_engines": [ + { + "id": "DNN_VM_HOST_CPU", + "name": "HOST_CPU", + "independent": false, + "skip_assign_stream": true, + "attach": true + }, { "id": "DNN_VM_GE_LOCAL", "name": "GE_LOCAL", diff --git a/src/ge/executor/CMakeLists.txt b/src/ge/executor/CMakeLists.txt index 1b0b8131..17508711 100755 --- a/src/ge/executor/CMakeLists.txt +++ b/src/ge/executor/CMakeLists.txt @@ -26,6 +26,9 @@ file(GLOB PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} file(GLOB SRC_LIST RELATIVE 
${CMAKE_CURRENT_LIST_DIR} "ge_executor.cc" + "../common/dump/dump_properties.cc" + "../common/dump/dump_manager.cc" + "../common/dump/dump_op.cc" "../common/ge/op_tiling_manager.cc" "../common/ge/plugin_manager.cc" "../common/profiling/profiling_manager.cc" diff --git a/src/ge/executor/ge_executor.cc b/src/ge/executor/ge_executor.cc index ee65faec..0d334042 100644 --- a/src/ge/executor/ge_executor.cc +++ b/src/ge/executor/ge_executor.cc @@ -23,6 +23,7 @@ #include "common/ge/ge_util.h" #include "common/helper/model_helper.h" #include "common/profiling/profiling_manager.h" +#include "common/dump/dump_manager.h" #include "common/util.h" #include "framework/common/debug/ge_log.h" #include "framework/common/util.h" @@ -35,6 +36,8 @@ #include "graph/utils/graph_utils.h" #include "mmpa/mmpa_api.h" #include "single_op/single_op_manager.h" +#include "graph/manager/graph_var_manager.h" +#include "graph/load/new_model_manager/davinci_model.h" using std::string; using std::vector; @@ -348,18 +351,46 @@ Status GeExecutor::SetDynamicDims(uint32_t model_id, void *dynamic_input_addr, u } vector cur_dynamic_dims; - if (GetCurDynamicDims(model_id, dynamic_dims, cur_dynamic_dims) != SUCCESS) { - GELOGE(FAILED, "GetCurDynamicDims failed."); + std::vector input_desc; + std::vector output_desc; + ret = GetModelDescInfo(model_id, input_desc, output_desc); + if (ret != ge::SUCCESS) { + GELOGE(FAILED, "GetModelDescInfo failed."); return FAILED; } - + vector user_designate_shape_order; + vector all_data_dims; + ret = GetUserDesignateShapeOrder(model_id, user_designate_shape_order); + if (ret != ge::SUCCESS) { + GELOGE(FAILED, "GetUserDesignateShapeOrder failed."); + return FAILED; + } + for (auto &data_name : user_designate_shape_order) { + for (size_t j = 0; j < input_desc.size(); ++j) { + if (input_desc.at(j).GetName() == data_name) { + for (auto dim : input_desc.at(j).GetShape().GetDims()) { + all_data_dims.push_back(dim); + } + break; + } + } + } + if (dynamic_dims.size() != 
all_data_dims.size()) { + GELOGE(FAILED, "Dynamic input size [%lu] is not equal with all data dims size [%lu]!", dynamic_dims.size(), + all_data_dims.size()); + return FAILED; + } + for (std::size_t i = 0; i < all_data_dims.size(); ++i) { + if (all_data_dims[i] < 0) { + cur_dynamic_dims.push_back(dynamic_dims[i]); + } + } size_t dynamic_dim_num = cur_dynamic_dims.size(); uint64_t dynamic_input_size = static_cast(dynamic_dim_num * sizeof(uint64_t)); if (length < dynamic_input_size) { GELOGE(FAILED, "Dynamic input size [%lu] is less than [%lu]!", length, dynamic_input_size); return FAILED; } - for (uint32_t i = 0; i < dynamic_dim_num; ++i) { // Memcpy dynamic dim[i] from host to device if (rtMemcpy(reinterpret_cast(reinterpret_cast(dynamic_input_addr) + sizeof(uint64_t) * i), @@ -549,6 +580,12 @@ Status GeExecutor::UnloadModel(uint32_t model_id) { GELOGE(ret, "[GraphLoader] DestroyAicpuSessionForInfer failed. model id: %u", model_id); return FAILED; } + + std::shared_ptr davinci_model = ModelManager::GetInstance()->GetModel(model_id); + if (davinci_model != nullptr) { + uint64_t session_id = davinci_model->GetSessionId(); + VarManagerPool::Instance().RemoveVarManager(session_id); + } return GraphLoader::UnloadModel(model_id); } @@ -658,6 +695,30 @@ Status GeExecutor::GetCombinedDynamicDims(uint32_t model_id, vector &user_designate_shape_order) { + GELOGI("Begin to get user designate shape info."); + if (!isInit_) { + GELOGE(GE_EXEC_NOT_INIT, "GeExecutor has not been initialized!"); + return GE_EXEC_NOT_INIT; + } + + Status ret = GraphExecutor::GetUserDesignateShapeOrder(model_id, user_designate_shape_order); + if (ret != SUCCESS) { + GELOGE(ret, "GetUserDesignateShapeOrder failed."); + return ret; + } + + GELOGI("Get user designate shape order succ."); + return SUCCESS; +} + /// /// @ingroup ge /// @brief Get AIPP input format @@ -674,7 +735,7 @@ Status GeExecutor::GetAIPPInfo(uint32_t model_id, uint32_t index, AippConfigInfo } Status ret = 
GraphExecutor::GetAIPPInfo(model_id, index, aipp_info); if (ret != SUCCESS) { - GELOGE(ret, "GetAIPPInfo failed."); + GELOGW("GetAIPPInfo is not success."); return ret; } GELOGI("GetAIPPInfo succ."); @@ -1020,4 +1081,26 @@ Status GeExecutor::GetAllAippInputOutputDims(uint32_t model_id, uint32_t index, GELOGI("GetAllAippInputOutputDims succ."); return SUCCESS; } + +Status GeExecutor::GetOpDescInfo(uint32_t device_id, uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info) { + GELOGI("Begin to GetOpDescInfo."); + Status ret = GraphExecutor::GetOpDescInfo(device_id, stream_id, task_id, op_desc_info); + if (ret != SUCCESS) { + GELOGE(ret, "GetOpDescInfo failed."); + return ret; + } + GELOGI("GetOpDescInfo succ."); + return SUCCESS; +} + +Status GeExecutor::SetDump(const DumpConfig &dump_config) { + GELOGI("Start to set dump config"); + auto ret = DumpManager::GetInstance().SetDumpConf(dump_config); + if (ret != SUCCESS) { + GELOGE(ret, "Set dump conf failed"); + return ret; + } + GELOGI("Set dump config succ."); + return SUCCESS; +} } // namespace ge diff --git a/src/ge/executor/module.mk b/src/ge/executor/module.mk index b19f3c24..878341b6 100644 --- a/src/ge/executor/module.mk +++ b/src/ge/executor/module.mk @@ -3,6 +3,9 @@ LOCAL_PATH := $(call my-dir) local_ge_executor_src_files := \ ge_executor.cc \ ../common/profiling/profiling_manager.cc \ + ../common/dump/dump_properties.cc \ + ../common/dump/dump_manager.cc \ + ../common/dump/dump_op.cc \ ../common/ge/plugin_manager.cc \ ../common/ge/op_tiling_manager.cc \ ../graph/load/graph_loader.cc \ diff --git a/src/ge/ge_inference.mk b/src/ge/ge_inference.mk index e3e1e10c..0cc0d6fb 100644 --- a/src/ge/ge_inference.mk +++ b/src/ge/ge_inference.mk @@ -26,6 +26,9 @@ COMMON_LOCAL_SRC_FILES := \ common/formats/format_transfers/format_transfer_nchw_fz_c04.cc \ common/formats/formats.cc \ common/profiling/profiling_manager.cc \ + common/dump/dump_properties.cc \ + common/dump/dump_manager.cc \ + common/dump/dump_op.cc \ 
common/helper/model_cache_helper.cc \ ge_local_engine/engine/host_cpu_engine.cc \ @@ -42,6 +45,7 @@ GRAPH_MANAGER_LOCAL_SRC_FILES := \ graph/manager/graph_manager_utils.cc \ graph/manager/graph_context.cc \ graph/preprocess/graph_preprocess.cc \ + graph/preprocess/multi_batch_options.cc \ graph/preprocess/multi_batch_copy_graph.cc \ graph/execute/graph_execute.cc \ graph/load/graph_loader.cc \ @@ -149,6 +153,7 @@ OMG_HOST_SRC_FILES := \ host_kernels/slice_kernel.cc \ host_kernels/slice_d_kernel.cc \ host_kernels/dynamic_stitch_kernel.cc \ + host_kernels/identity_kernel.cc \ graph/passes/stop_gradient_pass.cc \ graph/passes/prevent_gradient_pass.cc \ graph/passes/identity_pass.cc \ @@ -165,12 +170,16 @@ OMG_HOST_SRC_FILES := \ graph/passes/switch_to_stream_switch_pass.cc \ graph/passes/attach_stream_label_pass.cc \ graph/passes/multi_batch_pass.cc \ + graph/passes/multi_batch_clone_pass.cc \ + graph/passes/subexpression_migration_pass.cc \ + graph/passes/unused_args_clean_pass.cc \ graph/passes/next_iteration_pass.cc \ graph/passes/control_trigger_pass.cc \ graph/passes/cond_pass.cc \ graph/passes/cond_remove_pass.cc \ graph/passes/for_pass.cc \ graph/passes/enter_pass.cc \ + graph/passes/assign_pass.cc \ graph/passes/addn_pass.cc \ graph/passes/common_subexpression_elimination_pass.cc \ graph/passes/transop_symmetry_elimination_pass.cc \ @@ -185,11 +194,10 @@ OMG_HOST_SRC_FILES := \ graph/passes/transpose_transdata_pass.cc \ graph/passes/hccl_memcpy_pass.cc \ graph/passes/flow_ctrl_pass.cc \ + graph/passes/global_step_insert_pass.cc \ graph/passes/link_gen_mask_nodes_pass.cc \ graph/passes/replace_with_empty_const_pass.cc \ graph/passes/hccl_group_pass.cc \ - graph/passes/switch_fusion_pass.cc \ - graph/passes/switch_split_pass.cc \ graph/passes/memcpy_addr_async_pass.cc \ graph/passes/set_input_output_offset_pass.cc \ diff --git a/src/ge/ge_runner.mk b/src/ge/ge_runner.mk index a3119b50..66e2be5a 100644 --- a/src/ge/ge_runner.mk +++ b/src/ge/ge_runner.mk @@ -26,6 
+26,9 @@ LIBGE_LOCAL_SRC_FILES := \ common/ge/op_tiling_manager.cc\ common/helper/model_cache_helper.cc \ common/profiling/profiling_manager.cc \ + common/dump/dump_manager.cc \ + common/dump/dump_properties.cc \ + common/dump/dump_op.cc \ engine_manager/dnnengine_manager.cc \ ge_local_engine/engine/host_cpu_engine.cc \ generator/ge_generator.cc \ @@ -93,7 +96,6 @@ LIBGE_LOCAL_SRC_FILES := \ graph/manager/util/variable_accelerate_ctrl.cc \ graph/optimize/graph_optimize.cc \ graph/optimize/mem_rw_conflict_optimize.cc \ - graph/optimize/optimizer/allreduce_fusion_pass.cc \ graph/optimize/summary_optimize.cc \ graph/partition/engine_place.cc \ graph/partition/graph_partition.cc \ @@ -119,10 +121,10 @@ LIBGE_LOCAL_SRC_FILES := \ graph/passes/dimension_compute_pass.cc \ graph/passes/dropout_pass.cc \ graph/passes/hccl_group_pass.cc \ - graph/passes/switch_fusion_pass.cc \ - graph/passes/switch_split_pass.cc \ graph/passes/enter_pass.cc \ + graph/passes/assign_pass.cc \ graph/passes/flow_ctrl_pass.cc \ + graph/passes/global_step_insert_pass.cc \ host_kernels/transpose_kernel.cc \ host_kernels/add_kernel.cc \ host_kernels/broadcast_args_kernel.cc \ @@ -131,6 +133,7 @@ LIBGE_LOCAL_SRC_FILES := \ host_kernels/concat_offset_kernel.cc \ host_kernels/concat_v2_kernel.cc \ host_kernels/dynamic_stitch_kernel.cc \ + host_kernels/identity_kernel.cc \ host_kernels/empty_kernel.cc \ host_kernels/expanddims_kernel.cc \ host_kernels/fill_kernel.cc \ @@ -172,6 +175,9 @@ LIBGE_LOCAL_SRC_FILES := \ graph/passes/link_gen_mask_nodes_pass.cc \ graph/passes/merge_pass.cc \ graph/passes/multi_batch_pass.cc \ + graph/passes/multi_batch_clone_pass.cc \ + graph/passes/subexpression_migration_pass.cc \ + graph/passes/unused_args_clean_pass.cc \ graph/passes/net_output_pass.cc \ graph/passes/next_iteration_pass.cc \ graph/passes/no_use_reshape_remove_pass.cc \ @@ -225,6 +231,7 @@ LIBGE_LOCAL_SRC_FILES := \ graph/preprocess/graph_preprocess.cc \ graph/preprocess/insert_op/ge_aipp_op.cc \ 
graph/preprocess/insert_op/util_insert_aipp_op.cc \ + graph/preprocess/multi_batch_options.cc \ graph/preprocess/multi_batch_copy_graph.cc \ init/gelib.cc \ model/ge_model.cc \ @@ -267,10 +274,17 @@ LIBGE_LOCAL_SRC_FILES := \ hybrid/node_executor/aicpu/aicpu_ext_info.cc \ hybrid/node_executor/aicpu/aicpu_node_executor.cc \ hybrid/node_executor/compiledsubgraph/known_node_executor.cc \ - hybrid/node_executor/hostcpu/ge_local_node_executor.cc \ + hybrid/node_executor/ge_local/ge_local_node_executor.cc \ + hybrid/node_executor/host_cpu/host_cpu_node_executor.cc \ + hybrid/node_executor/host_cpu/kernel_factory.cc \ + hybrid/node_executor/host_cpu/kernel/no_op_kernel.cc \ + hybrid/node_executor/host_cpu/kernel/variable_kernel.cc \ + hybrid/node_executor/host_cpu/kernel/assign_kernel.cc \ + hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.cc \ hybrid/node_executor/controlop/control_op_executor.cc \ hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc \ hybrid/node_executor/hccl/hccl_node_executor.cc \ + hybrid/node_executor/rts/rts_node_executor.cc \ hybrid/node_executor/node_executor.cc \ hybrid/node_executor/task_context.cc \ hybrid/hybrid_davinci_model.cc \ @@ -343,7 +357,6 @@ LOCAL_SHARED_LIBRARIES := \ libgraph \ libregister \ libge_common \ - libhccl \ libmsprof \ liberror_manager \ @@ -425,7 +438,6 @@ LOCAL_SHARED_LIBRARIES := \ libc_sec \ libslog \ libmmpa \ - libhccl \ libmsprof \ LOCAL_LDFLAGS := -lrt -ldl @@ -457,7 +469,6 @@ LOCAL_SHARED_LIBRARIES := \ libc_sec \ libslog \ libmmpa \ - libhccl \ libmsprof \ LOCAL_LDFLAGS := -lrt -ldl diff --git a/src/ge/generator/ge_generator.cc b/src/ge/generator/ge_generator.cc index bc1e78c1..0d4fac3f 100644 --- a/src/ge/generator/ge_generator.cc +++ b/src/ge/generator/ge_generator.cc @@ -658,10 +658,13 @@ Status GeGenerator::Impl::BuildModel(const Graph &graph, const vector if (ret != SUCCESS) { GELOGE(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, "GraphManager build graph fail, graph id: %u", 
id); + VarManagerPool::Instance().RemoveVarManager(session_id); return GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED; } id += 1; + VarManagerPool::Instance().RemoveVarManager(session_id); + return SUCCESS; } diff --git a/src/ge/graph/build/graph_builder.cc b/src/ge/graph/build/graph_builder.cc index 51519023..ac83d4ec 100644 --- a/src/ge/graph/build/graph_builder.cc +++ b/src/ge/graph/build/graph_builder.cc @@ -28,6 +28,7 @@ #include "graph/common/ge_call_wrapper.h" #include "init/gelib.h" #include "model/ge_model.h" +#include "graph/ge_context.h" using domi::BuildMode; @@ -166,11 +167,15 @@ Status GraphBuilder::Build(ComputeGraphPtr &comp_graph, std::vector &subgraph_ptr_list, GeModelPtr &ge_model_ptr, - uint64_t session_id) { +Status GraphBuilder::BuildForKnownShapeGraph(ComputeGraphPtr &comp_graph, std::vector &subgraph_list, + GeModelPtr &ge_model_ptr, uint64_t session_id) { + if (ge::GetContext().GetHostExecFlag()) { + GE_CHK_STATUS_RET(BuildForHostCpuGraph(comp_graph, ge_model_ptr, session_id), "Build for host-cpu graph failed."); + return SUCCESS; + } + GELOGI("Begin to build known shape graph[%s].", comp_graph->GetName().c_str()); - Status ret = SecondPartition(comp_graph, subgraph_ptr_list); + Status ret = SecondPartition(comp_graph, subgraph_list); GE_CHK_STATUS_RET(ret, "Graph[%s] second partition Failed.", comp_graph->GetName().c_str()); auto subgraph_map = graph_partitioner_.GetSubGraphMap(); @@ -257,6 +262,10 @@ Status GraphBuilder::BuildForUnknownShapeGraph(ComputeGraphPtr &comp_graph, GeMo return SUCCESS; } +Status GraphBuilder::BuildForHostCpuGraph(ComputeGraphPtr &comp_graph, GeModelPtr &ge_model_ptr, uint64_t session_id) { + return BuildForUnknownShapeGraph(comp_graph, ge_model_ptr, session_id); +} + Status GraphBuilder::BuildForDynamicShapeGraph(ComputeGraphPtr &comp_graph, std::vector &subgraph_ptr_list, GeRootModelPtr &ge_root_model_ptr, GeModelPtr &ge_model_ptr, diff --git a/src/ge/graph/build/graph_builder.h 
b/src/ge/graph/build/graph_builder.h index def3a28b..dd229bc6 100644 --- a/src/ge/graph/build/graph_builder.h +++ b/src/ge/graph/build/graph_builder.h @@ -63,10 +63,12 @@ class GraphBuilder { Status BuildForDynamicShapeGraph(ComputeGraphPtr &comp_graph, std::vector &subgraph_ptr_list, GeRootModelPtr &ge_root_model_ptr, GeModelPtr &ge_model_ptr, uint64_t session_id = INVALID_SESSION_ID); - Status BuildForKnownShapeGraph(ComputeGraphPtr &comp_graph, std::vector &subgraph_ptr_list, + Status BuildForKnownShapeGraph(ComputeGraphPtr &comp_graph, std::vector &subgraph_list, GeModelPtr &ge_model_ptr, uint64_t session_id = INVALID_SESSION_ID); Status BuildForUnknownShapeGraph(ComputeGraphPtr &comp_graph, GeModelPtr &ge_model_ptr, uint64_t session_id = INVALID_SESSION_ID); + Status BuildForHostCpuGraph(ComputeGraphPtr &comp_graph, GeModelPtr &ge_model_ptr, + uint64_t session_id = INVALID_SESSION_ID); int build_mode_; std::map stream_max_parallel_num_; diff --git a/src/ge/graph/build/memory/block_mem_assigner.cc b/src/ge/graph/build/memory/block_mem_assigner.cc index 99b2fd7d..3d956230 100644 --- a/src/ge/graph/build/memory/block_mem_assigner.cc +++ b/src/ge/graph/build/memory/block_mem_assigner.cc @@ -745,6 +745,23 @@ bool BlockMemAssigner::IsContinuousOutput(const NodePtr &n) { return false; } +bool BlockMemAssigner::IsZeroCopyBlock(const NodePtr &node, bool continuous) { + if (NodeUtils::IsDynamicShape(node)) { + return ((node->GetType() == DATA_TYPE) && !continuous) || (node->GetType() == NETOUTPUT); + } + + if ((node->GetType() == DATA_TYPE) && !continuous) { + return !node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX); + } + + if (node->GetType() == NETOUTPUT) { + const auto &owner = node->GetOwnerComputeGraph(); + return owner->GetParentGraph() == nullptr; + } + + return false; +} + MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, size_t no_align_size, MemoryType mem_type, const NodePtr &n, uint32_t out_index, const vector 
&workspace_reuse_flag, const bool is_op_reuse_mem, @@ -793,9 +810,7 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(block == nullptr, return nullptr, "new an object failed."); // Data and netoutput need zero copy block - if ((node_op_desc->GetType() == DATA_TYPE && !continuous) || (node_op_desc->GetType() == NETOUTPUT)) { - block->is_zero_copy_ = true; - } + block->is_zero_copy_ = IsZeroCopyBlock(n, continuous); block->Init(real_size, mem_type, n, out_index, no_align_size); block->stream_id_ = node_op_desc->GetStreamId(); @@ -970,6 +985,14 @@ bool IsAtomicOutputMemory(const ge::NodePtr &node, uint32_t output_index, bool i return false; } +bool IsKnownSubgraphData(const NodePtr &node) { + if (NodeUtils::IsDynamicShape(node)) { + return false; + } + + return node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX); +} + void BlockMemAssigner::ReleaseMemory(MemoryBlock *to_release, vector &reusable_memory) { GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(to_release == nullptr, return, "Input parameter to_release is null."); GE_CHK_TRUE_EXEC_INFO(to_release->ref_count_ <= 0, return, "Release memory"); @@ -1092,7 +1115,7 @@ Status BlockMemAssigner::AssignOutputMemoryWithReuse(const NodePtr &node, vector (void)ge::AttrUtils::GetBool(op_desc, ATOMIC_ATTR_IS_ATOMIC_NODE, is_atomic); // Allocate memory for the current node and release node memory of the same size in the workspace GE_IF_BOOL_EXEC(ge_disable_reuse_mem_env_ != "1", - ReleaseMemorys(stream_workspace_blocks_[stream_id], reusable_blocks_[stream_id]);) + ReleaseMemorys(stream_workspace_blocks_[stream_id], reusable_blocks_[stream_id])); if (IsContinuousOutput(node)) { (void)ApplyContinuousMemory(node, ranges, is_op_reuse_mem_); return SUCCESS; @@ -1118,6 +1141,7 @@ Status BlockMemAssigner::AssignOutputMemoryWithReuse(const NodePtr &node, vector out_node_set_continuous_input = IsOutNodeSetContinuousInput(node, i, peer_name, peer_input_index); no_need_assign_memory = 
IsAtomicOutputMemory(node, i, is_atomic, out_node_set_continuous_input); } + no_need_assign_memory = (no_need_assign_memory || IsKnownSubgraphData(node)); if (no_need_assign_memory) { zero_memory_list_.emplace_back(node, kOutput, i, false); continue; @@ -1474,8 +1498,8 @@ void SetOffsetSize(const NodeTypeIndex &node_type, const MemoryBlock *block, siz return; } - if ((op_desc->GetType() == DATA) || (op_desc->GetType() == AIPP_DATA_TYPE) || (op_desc->GetType() == MULTISHAPE) || - (op_desc->GetType() == NETOUTPUT)) { + static const set kSetOffsetTypes = {DATA_TYPE, AIPP_DATA_TYPE, MULTISHAPE, NETOUTPUT}; + if ((kSetOffsetTypes.count(op_desc->GetType()) > 0) && !IsKnownSubgraphData(node_type.node)) { if ((output_list[node_type.index] == kInvalidOffset) || (output_list[node_type.index] < offset)) { output_list.at(node_type.index) = offset; } diff --git a/src/ge/graph/build/memory/block_mem_assigner.h b/src/ge/graph/build/memory/block_mem_assigner.h index 3dfba4c5..eedc7bec 100644 --- a/src/ge/graph/build/memory/block_mem_assigner.h +++ b/src/ge/graph/build/memory/block_mem_assigner.h @@ -352,6 +352,8 @@ class BlockMemAssigner : public MemAssigner { void AssignContinuousBlocks(); + bool IsZeroCopyBlock(const NodePtr &node, bool continuous); + bool IsOutNodeSetContinuousInput(const NodePtr &n, uint32_t out_index, std::string &peer_name, uint32_t &peer_input_index); diff --git a/src/ge/graph/build/memory/graph_mem_assigner.cc b/src/ge/graph/build/memory/graph_mem_assigner.cc index c5060dbd..affa82c8 100644 --- a/src/ge/graph/build/memory/graph_mem_assigner.cc +++ b/src/ge/graph/build/memory/graph_mem_assigner.cc @@ -1227,6 +1227,18 @@ ge::Status GraphMemoryAssigner::SetInputOffset() { return ge::SUCCESS; } +NodePtr GraphMemoryAssigner::GetKnownInputNode(const NodePtr &node) const { + if (!node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX)) { + return node; + } + + if (NodeUtils::IsDynamicShape(node)) { + return node; + } + + return NodeUtils::GetParentInput(node); +} 
+ ge::Status GraphMemoryAssigner::UpdateConstArgsOffset(const NodePtr &node, vector &input_list) const { uint32_t parent_index = 0; if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { @@ -1235,13 +1247,29 @@ ge::Status GraphMemoryAssigner::UpdateConstArgsOffset(const NodePtr &node, vecto // Subgraph Data Node, check for constant input. std::string op_type; - NodePtr in_node = NodeUtils::GetParentInput(node); - if (!NodeUtils::GetConstOpType(in_node, op_type)) { - return SUCCESS; // not constant input. + const auto &in_node = NodeUtils::GetParentInput(node); + if (NodeUtils::GetConstOpType(in_node, op_type)) { + input_list = in_node->GetOpDesc()->GetOutputOffset(); + node->GetOpDesc()->SetOutputOffset(input_list); // Set Data output same as const output. + return SUCCESS; // Constant input. + } + + // Memory allocated for dynamic shape subgraph Data. + if (NodeUtils::IsDynamicShape(node)) { + return SUCCESS; + } + + const auto &owner = node->GetOwnerComputeGraph(); + const auto &parent_desc = owner->GetParentNode()->GetOpDesc(); + const auto parent_inputs = parent_desc->GetInputOffset(); + if (parent_inputs.size() <= parent_index) { + GELOGE(FAILED, "Get Parent input offset failed, node: %s, input size: %zu, parent index: %u", + node->GetName().c_str(), parent_inputs.size(), parent_index); + return FAILED; } - vector const_input_list = in_node->GetOpDesc()->GetOutputOffset(); - node->GetOpDesc()->SetOutputOffset(const_input_list); // Set Data output same as const output. + input_list = {parent_inputs[parent_index]}; + node->GetOpDesc()->SetOutputOffset(input_list); // Set Data output same as parent input. 
return SUCCESS; } @@ -1287,7 +1315,8 @@ ge::Status GraphMemoryAssigner::UpdateOpInputOffset(const NodePtr &node, vector< input_list.back()); } else { int64_t output_offset = output_list.at(peer_out_anchor->GetIdx()); - if (peer_out_anchor->GetOwnerNode()->GetType() == CONSTANT) { + const auto &in_node = GetKnownInputNode(peer_out_anchor->GetOwnerNode()); + if (in_node->GetType() == CONSTANT) { GeTensorDesc tensor_desc = tmp_op_desc->GetInputDesc(input_index); GE_CHK_STATUS(TensorUtils::GetDataOffset(tensor_desc, output_offset)); } diff --git a/src/ge/graph/build/memory/graph_mem_assigner.h b/src/ge/graph/build/memory/graph_mem_assigner.h index afe9a4fa..daec2f75 100644 --- a/src/ge/graph/build/memory/graph_mem_assigner.h +++ b/src/ge/graph/build/memory/graph_mem_assigner.h @@ -181,6 +181,8 @@ class GraphMemoryAssigner { ge::Status UpdateConstArgsOffset(const NodePtr &node, vector &input_list) const; + NodePtr GetKnownInputNode(const NodePtr &node) const; + MemoryOffsetList memory_offset_; ge::ComputeGraphPtr compute_graph_; HybridMemAssignerPtr mem_assigner_; diff --git a/src/ge/graph/build/model_builder.cc b/src/ge/graph/build/model_builder.cc index 5435eb7b..9a314d80 100644 --- a/src/ge/graph/build/model_builder.cc +++ b/src/ge/graph/build/model_builder.cc @@ -182,38 +182,26 @@ void ModelBuilder::SetInputIsConst(const ge::NodePtr &n) { for (size_t i = 0; i < is_input_const.size(); i++) { is_input_const[i] = false; } + + std::string const_type; auto in_data_anchors = n->GetAllInDataAnchors(); for (size_t index = 0; index < in_data_anchors.size(); index++) { auto in_data_anchor = in_data_anchors.at(index); const auto &peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); const auto &src_node = peer_out_anchor->GetOwnerNode(); - if (src_node->GetType() == CONSTANT) { + if (!NodeUtils::GetConstOpType(src_node, const_type)) { + continue; + } + + if (const_type == CONSTANT) { if (!SetInputConst(node_op_desc, 
src_node, index, is_input_const)) { return; } - } else if (src_node->GetType() == CONSTANTOP) { + } else { if ((index < is_input_const.size()) && is_input_const[index]) { is_input_const[index] = false; } - } else if (src_node->GetType() == DATA) { - uint32_t parent_index = 0; - if (!AttrUtils::GetInt(src_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { - continue; - } - - // Subgraph Data Node, check for constant input. - std::string op_type; - const NodePtr in_node = NodeUtils::GetParentInput(src_node); - if (!NodeUtils::GetConstOpType(in_node, op_type)) { - continue; // not constant input. - } - - if (op_type == CONSTANT) { - if (!SetInputConst(node_op_desc, in_node, index, is_input_const)) { - return; - } - } } } diff --git a/src/ge/graph/build/stream_allocator.cc b/src/ge/graph/build/stream_allocator.cc index 5c82f461..b7643e47 100644 --- a/src/ge/graph/build/stream_allocator.cc +++ b/src/ge/graph/build/stream_allocator.cc @@ -16,6 +16,7 @@ #include "graph/build/stream_allocator.h" #include +#include #include "common/ge/ge_util.h" #include "framework/common/debug/ge_log.h" #include "framework/common/fmk_error_codes.h" @@ -374,8 +375,8 @@ Status StreamAllocator::InsertOneEventInTwoNodes(const NodePtr &cur_node, const return SUCCESS; } - if ((cur_node->GetType() == ENTER) || (cur_node->GetType() == REFENTER)) { - GELOGD("No need to insert event after enter_node %s.", cur_node->GetName().c_str()); + if (((cur_node->GetType() == ENTER) || (cur_node->GetType() == REFENTER)) && (next_node->GetType() != STREAMACTIVE)) { + GELOGD("No need to insert event between %s and %s.", cur_node->GetName().c_str(), next_node->GetName().c_str()); return SUCCESS; } @@ -721,6 +722,7 @@ Status StreamAllocator::SplitStreams(vector> &split_streams) { GELOGE(FAILED, "SplitStreams:streamid(%ld) > last_stream_id(%ld)", stream_id, last_stream_id); return FAILED; } + bool is_stream_first_node = (stream_node_num_vec[stream_id] == 0); AddNodeNum(cur_node, 
stream_node_num_vec[stream_id]); stream_2_nodes_map[stream_id].push_back(cur_node); // The maximum number of tasks per stream. @@ -737,7 +739,7 @@ Status StreamAllocator::SplitStreams(vector> &split_streams) { stream_continuous_2_nodes_map[continuous_stream_label].push_back(cur_node); } // Split the stream if it exceeds the maximum number of nodes in the stream. - if (NeedSpiltNewStream(stream_node_num_vec[stream_id], max_node_num_one_stream, op_desc)) { + if (NeedSpiltNewStream(stream_node_num_vec[stream_id], max_node_num_one_stream, op_desc, is_stream_first_node)) { last_stream_id++; GELOGI( "stream_node_num_vec[%ld]= %ld > max_node_num_one_stream : %ld, " @@ -801,7 +803,11 @@ Status StreamAllocator::SplitStreams(vector> &split_streams) { } bool StreamAllocator::NeedSpiltNewStream(int64_t stream_node_num, int64_t max_node_num_one_stream, - const OpDescPtr &op_desc) const { + const OpDescPtr &op_desc, bool is_stream_first_node) const { + if (is_stream_first_node) { + GELOGD("First node of stream does not need to split new stream"); + return false; + } const set label_op_types({LABELSET, LABELGOTO, LABELGOTOEX, LABELSWITCH, LABELSWITCHBYINDEX}); bool is_first_active_node = false; (void)AttrUtils::GetBool(op_desc, ATTR_NAME_SUBGRAPH_FIRST_ACTIVE, is_first_active_node); @@ -1019,6 +1025,18 @@ Status StreamAllocator::SetActiveStreamsForLoop() { loop_active_streams.emplace_back(static_cast(stream_id)); } } + map stream_id_to_last_node; + set streams_skip_iterator_event; + for (const auto &node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { + int64_t stream_id = node->GetOpDesc()->GetStreamId(); + if (find(loop_active_streams.begin(), loop_active_streams.end(), stream_id) != loop_active_streams.end()) { + stream_id_to_last_node[stream_id] = node; + // last node in stream which has streamswitch or IF may be not execute, it will cause block if add event on them + if (node->GetOpDesc()->GetType() == STREAMSWITCH) { + 
streams_skip_iterator_event.insert(stream_id); + } + } + } // Set the stream that needs to be activated for (const auto &node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { GE_CHECK_NOTNULL(node->GetOpDesc()); @@ -1031,7 +1049,31 @@ Status StreamAllocator::SetActiveStreamsForLoop() { GELOGE(FAILED, "SetListInt failed."); return FAILED); for (const auto &stream_id : loop_active_streams) { - GELOGI("Active stream %u for node: %s", stream_id, node->GetName().c_str()); + GELOGI("Active stream %u for node: %s.", stream_id, node->GetName().c_str()); + } + + // In switch group optimze case, some data input branch may exec slowly. + // when condition input branch judge false and some switch has no false branch, + // In this condition, data branch has no synchronize point, + // it may cause some stream actived by iterator next step when this stream still alive. + // If above situation happen, active message will lose, cause process block in next iteration. + // In order to avoid this abnormal happen, + // add event between each last node and iterator active node in target active stream + GELOGI("there are %zu next iterator target streams has streamswitch node.", streams_skip_iterator_event.size()); + for (auto iter : stream_id_to_last_node) { + if (streams_skip_iterator_event.find(iter.first) != streams_skip_iterator_event.end()) { + GELOGI("skip stream %ld which has streamswitch node when add event to next iterator active node", + iter.first); + continue; + } + if (iter.second->GetOwnerComputeGraph()->GetParentGraph() != nullptr) { + GELOGI("skip stream %ld which last node in subgraph when add event to next iterator active node", + iter.first); + continue; + } + AddSendEventId(iter.second, event_num_); + AddRecvEventId(node, event_num_); + event_num_++; } break; @@ -1132,7 +1174,7 @@ Status StreamAllocator::InsertSyncEventNodes() { return status; } - GELOGI("Insert recv event %u before node: %s", event_id, node->GetName().c_str()); + GELOGI("Insert recv 
event %u before node: %s.", event_id, node->GetName().c_str()); } // Add the node corresponding to the send event @@ -1160,7 +1202,7 @@ Status StreamAllocator::InsertSyncEventNodes() { return status; } - GELOGI("Insert send event %u after node: %s", event_id, node->GetName().c_str()); + GELOGI("Insert send event %u after node: %s.", event_id, node->GetName().c_str()); } } diff --git a/src/ge/graph/build/stream_allocator.h b/src/ge/graph/build/stream_allocator.h index a5326a39..0158e6b0 100644 --- a/src/ge/graph/build/stream_allocator.h +++ b/src/ge/graph/build/stream_allocator.h @@ -58,7 +58,8 @@ class StreamAllocator { bool IsActiveAfterNextIteration(const NodePtr &active_node_ptr) const; Status SplitStreams(std::vector> &split_streams); - bool NeedSpiltNewStream(int64_t stream_node_num, int64_t max_node_num_one_stream, const OpDescPtr &op_desc) const; + bool NeedSpiltNewStream(int64_t stream_node_num, int64_t max_node_num_one_stream, const OpDescPtr &op_desc, + bool is_stream_first_node) const; Status UpdateActiveStreams(const std::vector> &split_streams); void UpdateLabelStreams(const std::vector> &split_streams); diff --git a/src/ge/graph/build/task_generator.cc b/src/ge/graph/build/task_generator.cc index 41a845a2..91f70f2a 100644 --- a/src/ge/graph/build/task_generator.cc +++ b/src/ge/graph/build/task_generator.cc @@ -95,8 +95,8 @@ Status TaskGenerator::GetTaskInfo(Model &model, ComputeGraphPtr &graph, uint64_t GELOGE(FAILED, "SetListStr failed."); return FAILED); - GELOGI("Call GenerateTask Success, task_def_list.size:%zu, op_name_map.size:%zu", task_def_list.size(), - op_name_map.size()); + GELOGI("GenerateTask Success, task list:%zu, op map:%zu, logic mem base:%p, logic weight base:%p, logic var base:%p", + task_def_list.size(), op_name_map.size(), run_context.dataMemBase, run_context.weightMemBase, var_mem_base_); // Init and serialize model_task_def ModelTaskDef model_task_def; @@ -260,7 +260,7 @@ Status TaskGenerator::GenerateTask(RunContext 
&run_context, ComputeGraphPtr &gra int64_t group_key; uint32_t node_index = 0; rtStream_t stream = nullptr; - bool is_unknown_shape = graph->GetGraphUnknownFlag(); + bool is_unknown_shape = graph->GetGraphUnknownFlag() || GetContext().GetHostExecFlag(); if (is_unknown_shape) { GE_CHK_STATUS_RET(SetUnknownShapeStream(run_context, stream), "Set unknown shape stream failed."); } @@ -479,7 +479,12 @@ Status TaskGenerator::UpdateAnchorStatus(const NodePtr &node) { GELOGE(INTERNAL_ERROR, "AnchorUtils::SetStatus failed."); return INTERNAL_ERROR; } - } else if (peer_anchor->GetOwnerNode()->GetType() == CONSTANT) { + continue; + } + + std::string const_type; + bool is_const = NodeUtils::GetConstOpType(peer_anchor->GetOwnerNode(), const_type); + if (is_const && (const_type == CONSTANT)) { if (AnchorUtils::SetStatus(anchor, ANCHOR_CONST) != GRAPH_SUCCESS) { GELOGE(INTERNAL_ERROR, "AnchorUtils::SetStatus failed."); return INTERNAL_ERROR; diff --git a/src/ge/graph/common/transop_util.cc b/src/ge/graph/common/transop_util.cc index 3250929d..eb80fb69 100644 --- a/src/ge/graph/common/transop_util.cc +++ b/src/ge/graph/common/transop_util.cc @@ -17,9 +17,13 @@ #include "graph/common/transop_util.h" #include "common/types.h" +#include "graph/utils/type_utils.h" +#include "framework/common/debug/ge_log.h" namespace { const int kInvalidTransopDataIndex = -1; +const int kTransOpOutIndex = 0; +std::map precision_loss_transfer_map = {{ge::DT_FLOAT, ge::DT_BOOL}}; } // namespace namespace ge { @@ -60,4 +64,20 @@ int TransOpUtil::GetTransOpDataIndex(const std::string &type) { } return kInvalidTransopDataIndex; } + +bool TransOpUtil::CheckPrecisionLoss(const ge::NodePtr &src_node) { + auto idx = TransOpUtil::GetTransOpDataIndex(src_node); + auto input_desc = src_node->GetOpDesc()->GetInputDesc(idx); + auto output_desc = src_node->GetOpDesc()->GetOutputDesc(kTransOpOutIndex); + auto src_dtype = input_desc.GetDataType(); + auto dst_dtype = output_desc.GetDataType(); + auto iter = 
precision_loss_transfer_map.find(src_dtype); + if (iter != precision_loss_transfer_map.end() && iter->second == dst_dtype) { + GELOGW("Node %s transfer data type from %s to %s ,it will cause precision loss. ignore pass.", + src_node->GetName().c_str(), TypeUtils::DataTypeToSerialString(src_dtype).c_str(), + TypeUtils::DataTypeToSerialString(dst_dtype).c_str()); + return false; + } + return true; +} } // namespace ge diff --git a/src/ge/graph/common/transop_util.h b/src/ge/graph/common/transop_util.h index 041a7637..8b10ad5c 100644 --- a/src/ge/graph/common/transop_util.h +++ b/src/ge/graph/common/transop_util.h @@ -33,6 +33,8 @@ class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY TransOpUtil { static int GetTransOpDataIndex(const std::string &type); + static bool CheckPrecisionLoss(const NodePtr &src_node); + private: TransOpUtil(); diff --git a/src/ge/graph/execute/graph_execute.cc b/src/ge/graph/execute/graph_execute.cc index 1bebd382..25208aa4 100644 --- a/src/ge/graph/execute/graph_execute.cc +++ b/src/ge/graph/execute/graph_execute.cc @@ -519,6 +519,25 @@ Status GraphExecutor::GetCombinedDynamicDims(uint32_t model_id, std::vector &user_input_shape_order) { + auto model_manager = ge::ModelManager::GetInstance(); + GE_CHECK_NOTNULL(model_manager); + Status ret = model_manager->GetUserDesignateShapeOrder(model_id, user_input_shape_order); + if (ret != SUCCESS) { + GELOGE(ret, "GetUserDesignateShapeOrder failed."); + return ret; + } + return SUCCESS; +} + Status GraphExecutor::GetCurShape(const uint32_t model_id, std::vector &batch_info, int32_t &dynamic_type) { auto model_manager = ge::ModelManager::GetInstance(); GE_CHECK_NOTNULL(model_manager); @@ -570,7 +589,7 @@ Status GraphExecutor::GetAIPPInfo(uint32_t model_id, uint32_t index, AippConfigI GE_CHECK_NOTNULL(model_manager); Status ret = model_manager->GetAIPPInfo(model_id, index, aipp_info); if (ret != SUCCESS) { - GELOGE(ret, "GetAIPPInfo failed."); + GELOGW("GetAIPPInfo is not success."); return ret; } @@ 
-602,4 +621,16 @@ Status GraphExecutor::GetAllAippInputOutputDims(uint32_t model_id, uint32_t inde return SUCCESS; } + +Status GraphExecutor::GetOpDescInfo(uint32_t device_id, uint32_t stream_id, uint32_t task_id, + OpDescInfo &op_desc_info) { + auto model_manager = ge::ModelManager::GetInstance(); + GE_CHECK_NOTNULL(model_manager); + Status ret = model_manager->GetOpDescInfo(device_id, stream_id, task_id, op_desc_info); + if (ret != SUCCESS) { + GELOGE(ret, "GetOpDescInfo failed."); + return ret; + } + return SUCCESS; +} } // namespace ge diff --git a/src/ge/graph/execute/graph_execute.h b/src/ge/graph/execute/graph_execute.h index f79a2e29..5cf39bae 100644 --- a/src/ge/graph/execute/graph_execute.h +++ b/src/ge/graph/execute/graph_execute.h @@ -95,6 +95,15 @@ class GraphExecutor { /// static Status GetCombinedDynamicDims(uint32_t model_id, std::vector> &batch_info); + /// + /// @ingroup ge + /// @brief Get user designate shape order + /// @param [in] model_id + /// @param [out] user_input_shape_order + /// @return execute result + /// + static Status GetUserDesignateShapeOrder(uint32_t model_id, std::vector &user_input_shape_order); + static Status GetCurShape(const uint32_t model_id, std::vector &batch_info, int32_t &dynamic_type); static Status GetModelAttr(uint32_t model_id, std::vector &dynamic_output_shape_info); @@ -107,6 +116,8 @@ class GraphExecutor { static Status GetAllAippInputOutputDims(uint32_t model_id, uint32_t index, std::vector &input_dims, std::vector &output_dims); + static Status GetOpDescInfo(uint32_t device_id, uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info); + private: Status PrepareInputData(const std::vector &input_tensor, InputData &graph_input_data, OutputData &graph_output_data, std::vector &output_desc); diff --git a/src/ge/graph/load/new_model_manager/aipp_utils.cc b/src/ge/graph/load/new_model_manager/aipp_utils.cc index e7ae922c..0a348109 100644 --- a/src/ge/graph/load/new_model_manager/aipp_utils.cc +++ 
b/src/ge/graph/load/new_model_manager/aipp_utils.cc @@ -38,7 +38,9 @@ namespace ge { Status AippUtils::ConvertAippParams2AippInfo(domi::AippOpParams *aipp_params, AippConfigInfo &aipp_info) { GE_CHECK_NOTNULL(aipp_params); + AIPP_CONVERT_TO_AIPP_INFO(aipp_mode); AIPP_CONVERT_TO_AIPP_INFO(input_format); + AIPP_CONVERT_TO_AIPP_INFO(related_input_rank); AIPP_CONVERT_TO_AIPP_INFO(src_image_size_w); AIPP_CONVERT_TO_AIPP_INFO(src_image_size_h); AIPP_CONVERT_TO_AIPP_INFO(crop); @@ -85,6 +87,8 @@ Status AippUtils::ConvertAippParams2AippInfo(domi::AippOpParams *aipp_params, Ai AIPP_CONVERT_TO_AIPP_INFO_WITH_INDEX(var_reci_chn_1, 0); AIPP_CONVERT_TO_AIPP_INFO_WITH_INDEX(var_reci_chn_2, 0); AIPP_CONVERT_TO_AIPP_INFO_WITH_INDEX(var_reci_chn_3, 0); + AIPP_CONVERT_TO_AIPP_INFO(support_rotation); + AIPP_CONVERT_TO_AIPP_INFO(max_src_image_size); return SUCCESS; } } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/data_dumper.cc b/src/ge/graph/load/new_model_manager/data_dumper.cc index 7194264d..b94add80 100644 --- a/src/ge/graph/load/new_model_manager/data_dumper.cc +++ b/src/ge/graph/load/new_model_manager/data_dumper.cc @@ -171,6 +171,44 @@ void DataDumper::SaveOpDebugId(uint32_t task_id, uint32_t stream_id, void *op_de is_op_debug_ = is_op_debug; } +void DataDumper::SaveDumpOpInfo(const RuntimeParam &model_param, const OpDescPtr &op, uint32_t task_id, + uint32_t stream_id) { + GELOGD("Start SaveDumpOpInfo of task_id: %u, stream_id: %u", task_id, stream_id); + OpDescInfo op_desc_info; + op_desc_info.op_name = op->GetName(); + op_desc_info.task_id = task_id; + op_desc_info.stream_id = stream_id; + for (size_t i = 0; i < op->GetInputsSize(); ++i) { + GeTensorDesc input_desc = op->GetInputDesc(i); + op_desc_info.input_format.emplace_back(input_desc.GetFormat()); + op_desc_info.input_shape.emplace_back(input_desc.GetShape().GetDims()); + op_desc_info.input_data_type.emplace_back(input_desc.GetDataType()); + } + for (size_t j = 0; j < op->GetOutputsSize(); ++j) { + 
GeTensorDesc output_desc = op->GetOutputDesc(j); + op_desc_info.output_format.emplace_back(output_desc.GetFormat()); + op_desc_info.output_shape.emplace_back(output_desc.GetShape().GetDims()); + op_desc_info.output_data_type.emplace_back(output_desc.GetDataType()); + } + op_desc_info.input_addrs = ModelUtils::GetInputDataAddrs(model_param, op); + op_desc_info.output_addrs = ModelUtils::GetOutputDataAddrs(model_param, op); + + op_desc_info_.emplace_back(op_desc_info); +} + +bool DataDumper::GetOpDescInfo(uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info) const { + GELOGI("There are %zu op need to dump.", op_desc_info_.size()); + for (size_t index = 0; index < op_desc_info_.size(); ++index) { + OpDescInfo dump_op_info = op_desc_info_.at(index); + if (dump_op_info.task_id == task_id && dump_op_info.stream_id == stream_id) { + GELOGI("find exception op of task_id: %u, stream_id: %u.", task_id, stream_id); + op_desc_info = dump_op_info; + return true; + } + } + return false; +} + void DataDumper::SaveDumpTask(uint32_t task_id, uint32_t stream_id, const std::shared_ptr &op_desc, uintptr_t args) { if (op_desc == nullptr) { @@ -325,17 +363,24 @@ Status DataDumper::DumpOutputWithTask(const InnerDumpInfo &inner_dump_info, aicp // check dump output tensor desc is redirected by attr ATTR_DATA_DUMP_REF if (AttrUtils::GetStr(&output_desc, ATTR_DATA_DUMP_REF, node_name_index)) { GE_CHK_STATUS_RET(DumpRefOutput(inner_dump_info, output, i, node_name_index), "DumpRefOutput failed"); + task.mutable_output()->Add(std::move(output)); } else { - GE_IF_BOOL_EXEC( - IsTensorDescWithSkipDumpAddrType(has_mem_type_attr, v_memory_type, i), - GELOGD("DumpOutputWithTask[%s] output[%zu] is l1 addr, skip it", inner_dump_info.op->GetName().c_str(), i); - continue;); - - const auto input_size = inner_dump_info.op->GetInputsSize(); - auto addr = inner_dump_info.args + (i + input_size) * kAddrLen; - GE_CHK_STATUS_RET(GenerateOutput(output, output_descs, addr, i), "Generate output 
failed"); + if (IsTensorDescWithSkipDumpAddrType(has_mem_type_attr, v_memory_type, i)) { + GELOGI("[L1Fusion] DumpOutputWithTask[%s] output[%zu] is l1 addr.", inner_dump_info.op->GetName().c_str(), i); + int64_t output_size = 0; + if (TensorUtils::GetTensorSizeInBytes(output_descs.at(i), output_size) != SUCCESS) { + GELOGE(PARAM_INVALID, "Get output size failed."); + return PARAM_INVALID; + } + GELOGI("Get output size of l1_fusion_dump is %ld", output_size); + GenerateOpBuffer(output_size, task); + } else { + const auto input_size = inner_dump_info.op->GetInputsSize(); + auto addr = inner_dump_info.args + (i + input_size) * kAddrLen; + GE_CHK_STATUS_RET(GenerateOutput(output, output_descs, addr, i), "Generate output failed"); + task.mutable_output()->Add(std::move(output)); + } } - task.mutable_output()->Add(std::move(output)); } return SUCCESS; } @@ -468,20 +513,38 @@ Status DataDumper::DumpInput(const InnerDumpInfo &inner_dump_info, aicpu::dump:: // check dump input tensor desc is redirected by attr ATTR_DATA_DUMP_REF if (AttrUtils::GetStr(&input_descs.at(i), ATTR_DATA_DUMP_REF, node_name_index)) { GE_CHK_STATUS_RET(DumpRefInput(inner_dump_info, input, i, node_name_index), "DumpRefInput failed"); + task.mutable_input()->Add(std::move(input)); // normal dump without attr } else { - GE_IF_BOOL_EXEC(IsTensorDescWithSkipDumpAddrType(has_mem_type_attr, v_memory_type, i), - GELOGD("DumpInput[%s] input[%zu] is l1 addr, skip it", inner_dump_info.op->GetName().c_str(), i); - continue;); - - auto addr = inner_dump_info.args + kAddrLen * i; - GE_CHK_STATUS_RET(GenerateInput(input, input_descs, addr, i), "Generate input failed"); + if (IsTensorDescWithSkipDumpAddrType(has_mem_type_attr, v_memory_type, i)) { + GELOGI("[L1Fusion] DumpInput[%s] input[%zu] is l1 addr", inner_dump_info.op->GetName().c_str(), i); + int64_t input_size = 0; + if (AttrUtils::GetInt(input_descs.at(i), ATTR_NAME_INPUT_ORIGIN_SIZE, input_size)) { + GELOGI("Get aipp input size according to attr is %ld", 
input_size); + } else if (TensorUtils::GetTensorSizeInBytes(input_descs.at(i), input_size) != SUCCESS) { + GELOGE(PARAM_INVALID, "Get input size failed."); + return PARAM_INVALID; + } + GELOGI("Get input size of l1_fusion_dump is %ld", input_size); + GenerateOpBuffer(input_size, task); + } else { + auto addr = inner_dump_info.args + kAddrLen * i; + GE_CHK_STATUS_RET(GenerateInput(input, input_descs, addr, i), "Generate input failed"); + task.mutable_input()->Add(std::move(input)); + } } - task.mutable_input()->Add(std::move(input)); } return SUCCESS; } +void DataDumper::GenerateOpBuffer(const int64_t &size, aicpu::dump::Task &task) { + aicpu::dump::OpBuffer op_buffer; + op_buffer.set_buffer_type(aicpu::dump::BufferType::L1); + op_buffer.set_address(reinterpret_cast(l1_fusion_addr_)); + op_buffer.set_size(size); + task.mutable_buffer()->Add(std::move(op_buffer)); +} + Status DataDumper::ExecuteLoadDumpInfo(aicpu::dump::OpMappingInfo &op_mapping_info) { std::string proto_str; size_t proto_size = op_mapping_info.ByteSizeLong(); @@ -720,7 +783,7 @@ void DataDumper::PrintCheckLog(string &dump_list_key) { bool not_find_by_omname = model_list.find(om_name_) == model_list.end(); bool not_find_by_modelname = model_list.find(model_name_) == model_list.end(); dump_list_key = not_find_by_omname ? 
model_name_ : om_name_; - GELOGI("%zu op need dump in %s.", op_list_.size(), dump_list_key.c_str()); + GELOGI("%zu op need dump in known shape model %s.", op_list_.size(), dump_list_key.c_str()); if (model_list.find(DUMP_ALL_MODEL) == model_list.end()) { if (not_find_by_omname && not_find_by_modelname) { diff --git a/src/ge/graph/load/new_model_manager/data_dumper.h b/src/ge/graph/load/new_model_manager/data_dumper.h index 0648a8ce..cb5bbd41 100644 --- a/src/ge/graph/load/new_model_manager/data_dumper.h +++ b/src/ge/graph/load/new_model_manager/data_dumper.h @@ -30,6 +30,7 @@ #include "proto/op_mapping_info.pb.h" #include "runtime/mem.h" #include "task_info/task_info.h" +#include "framework/common/ge_types.h" namespace ge { class DataDumper { @@ -64,10 +65,14 @@ class DataDumper { void SetRefInfo(const std::map &ref_info) { ref_info_ = ref_info; }; + void SetL1FusionAddr(void *addr) { l1_fusion_addr_ = addr; }; + void SetLoopAddr(void *global_step, void *loop_per_iter, void *loop_cond); void SaveDumpInput(const std::shared_ptr &node); + void SaveDumpOpInfo(const RuntimeParam &model_param, const OpDescPtr &op, uint32_t task_id, uint32_t stream_id); + // args is device memory stored first output addr void SaveDumpTask(uint32_t task_id, uint32_t stream_id, const std::shared_ptr &op_desc, uintptr_t args); void SaveEndGraphId(uint32_t task_id, uint32_t stream_id); @@ -81,6 +86,7 @@ class DataDumper { void SetDumpProperties(const DumpProperties &dump_properties) { dump_properties_ = dump_properties; } const DumpProperties &GetDumpProperties() const { return dump_properties_; } + bool GetOpDescInfo(uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info) const; private: void ReleaseDevMem(void **ptr) noexcept; @@ -100,6 +106,7 @@ class DataDumper { struct InnerDumpInfo; struct InnerInputMapping; + std::vector op_desc_info_; std::vector op_list_; uint32_t end_graph_task_id_ = 0; uint32_t end_graph_stream_id_ = 0; @@ -111,6 +118,7 @@ class DataDumper { uintptr_t 
loop_cond_; ComputeGraphPtr compute_graph_; std::map ref_info_; + void *l1_fusion_addr_ = nullptr; uint32_t op_debug_task_id_ = 0; uint32_t op_debug_stream_id_ = 0; @@ -135,6 +143,7 @@ class DataDumper { const uintptr_t &addr, size_t index); Status GenerateOutput(aicpu::dump::Output &output, const OpDesc::Vistor &tensor_descs, const uintptr_t &addr, size_t index); + void GenerateOpBuffer(const int64_t &size, aicpu::dump::Task &task); }; struct DataDumper::InnerDumpInfo { uint32_t task_id; diff --git a/src/ge/graph/load/new_model_manager/davinci_model.cc b/src/ge/graph/load/new_model_manager/davinci_model.cc index 5af366a5..7daeb1b8 100644 --- a/src/ge/graph/load/new_model_manager/davinci_model.cc +++ b/src/ge/graph/load/new_model_manager/davinci_model.cc @@ -84,6 +84,8 @@ const uint32_t kAddrLen = sizeof(void *); const int kDecimal = 10; const int kBytes = 8; const uint32_t kDataMemAlignSizeCompare = 64; +const uint32_t kDumpL1FusionOpMByteSize = 2 * 1024 * 1024; +const uint32_t kDumpFlagOfL1Fusion = 0; const char *const kDefaultBatchLable = "Batch_default"; inline bool IsDataOp(const std::string &node_type) { @@ -97,7 +99,6 @@ inline bool IsNoTaskAndDumpNeeded(const OpDescPtr &op_desc) { } // namespace std::mutex DavinciModel::tvm_bin_mutex_; -std::set DavinciModel::tvm_bin_kernel_; DavinciModel::DavinciModel(int32_t priority, const std::shared_ptr &listener) : weights_mem_base_(nullptr), @@ -179,6 +180,10 @@ DavinciModel::~DavinciModel() { FreeFeatureMapMem(); + if (l1_fusion_addr_ != nullptr) { + GE_CHK_RT(rtFree(l1_fusion_addr_)); + } + if (rt_model_handle_ != nullptr) { GE_CHK_RT(rtModelDestroy(rt_model_handle_)); rt_model_handle_ = nullptr; @@ -305,7 +310,7 @@ Status DavinciModel::InitModelMem(void *dev_ptr, size_t mem_size, void *weight_p if (weight_ptr == nullptr) { weights_mem_base_ = MallocWeightsMem(weights_size); if (weights_mem_base_ == nullptr) { - GELOGE(GE_EXEC_ALLOC_FEATURE_MAP_MEM_FAILED, "Alloc weight memory failed. 
size: %zu", weights_size); + GELOGE(GE_EXEC_ALLOC_WEIGHT_MEM_FAILED, "Alloc weight memory failed. size: %zu", weights_size); return GE_EXEC_ALLOC_WEIGHT_MEM_FAILED; } is_inner_weight_base_ = true; @@ -367,7 +372,7 @@ void DavinciModel::InitRuntimeParams() { session_id_ = runtime_param_.session_id; GELOGI( - "InitRuntimeParams(), session_id:%u, stream_num:%lu, event_num:%u, label_num:%u, " + "InitRuntimeParams(), session_id:%lu, stream_num:%u, event_num:%u, label_num:%u, " "logic_mem_base:0x%lx, logic_weight_base:0x%lx, logic_var_base:0x%lx, " "memory_size:%lu, weight_size:%lu, var_size:%lu", runtime_param_.session_id, runtime_param_.stream_num, runtime_param_.event_num, runtime_param_.label_num, @@ -401,6 +406,7 @@ void DavinciModel::CheckHasHcomOp() { /// Status DavinciModel::BindModelStream() { // Stream not in active_stream_indication_ is active stream. + is_stream_list_bind_ = false; if ((!input_queue_ids_.empty() || !output_queue_ids_.empty()) || (deploy_type_ == AICPU_DEPLOY_CROSS_THREAD)) { for (size_t i = 0; i < stream_list_.size(); ++i) { if (active_stream_indication_.count(i) == 0) { @@ -419,7 +425,7 @@ Status DavinciModel::BindModelStream() { GE_CHK_RT_RET(rtModelBindStream(rt_model_handle_, stream_list_[i], RT_HEAD_STREAM)); } } - + is_stream_list_bind_ = true; return SUCCESS; } @@ -600,6 +606,12 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size // create model_handle to load model GE_CHK_RT_RET(rtModelCreate(&rt_model_handle_, 0)); GE_CHK_RT_RET(rtModelGetId(rt_model_handle_, &runtime_model_id_)); + // malloc 2M for dump l1fusion op + GE_CHK_RT_RET(rtMalloc(&l1_fusion_addr_, kDumpL1FusionOpMByteSize, RT_MEMORY_DDR)); + + // send l1fusion dump addr to rts + GE_CHK_RT_RET(rtDumpAddrSet(rt_model_handle_, l1_fusion_addr_, kDumpL1FusionOpMByteSize, kDumpFlagOfL1Fusion)); + // inference will use default graph_id 0; runtime_param_.graph_id = compute_graph->GetGraphID(); @@ -748,11 +760,18 @@ Status 
DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { typedef Status (DavinciModel::*OpDescCall)(const OpDescPtr &); static std::map op_desc_handle = { - {VARIABLE, &DavinciModel::InitVariable}, {CONSTANTOP, &DavinciModel::InitConstant}, - {STREAMACTIVE, &DavinciModel::InitStreamActive}, {STREAMSWITCH, &DavinciModel::InitStreamSwitch}, - {STREAMSWITCHN, &DavinciModel::InitStreamSwitchN}, {LABELSET, &DavinciModel::InitLabelSet}, + {VARIABLE, &DavinciModel::InitVariable}, + {CONSTANTOP, &DavinciModel::InitConstant}, + {STREAMACTIVE, &DavinciModel::InitStreamActive}, + {STREAMSWITCH, &DavinciModel::InitStreamSwitch}, + {STREAMSWITCHN, &DavinciModel::InitStreamSwitchN}, + {LABELSET, &DavinciModel::InitLabelSet}, + {CASE, &DavinciModel::InitCase}, }; + GE_CHK_STATUS_RET(InitInputOutputForDynamic(compute_graph), "InitInputOutputForDynamic failed."); + + map data_by_index; auto nodes = compute_graph->GetAllNodes(); const TBEKernelStore &tbekernel_store = ge_model_->GetTBEKernelStore(); for (size_t i = 0; i < nodes.size(); i++) { @@ -770,7 +789,7 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { GE_TIMESTAMP_ADD(LoadTBEKernelBinToOpDesc); if (IsDataOp(op_desc->GetType())) { - if (InitDataOp(node, data_op_index) != SUCCESS) { + if (InitDataOp(node, data_op_index, data_by_index) != SUCCESS) { GELOGE(PARAM_INVALID, "Data init failed, Name: %s", op_desc->GetName().c_str()); return PARAM_INVALID; } @@ -839,21 +858,44 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { GE_TIMESTAMP_ADD(InitTbeHandle); } + AdjustDataOpList(data_by_index); GE_TIMESTAMP_CALLNUM_END(LoadTBEKernelBinToOpDesc, "GraphLoader::LoadTBEKernelBinToOpDesc."); GE_TIMESTAMP_CALLNUM_END(InitTbeHandle, "GraphLoader::InitTbeHandle."); return SUCCESS; } +Status DavinciModel::InitInputOutputForDynamic(const ComputeGraphPtr &compute_graph) { + if (!known_node_) return SUCCESS; + // for dynamic shape + auto direct_nodes = compute_graph->GetDirectNode(); + for 
(size_t i = 0; i < direct_nodes.size(); i++) { + auto node = direct_nodes.at(i); + auto op_desc = node->GetOpDesc(); + if (op_desc == nullptr) { + GELOGE(PARAM_INVALID, "op_desc is null."); + return PARAM_INVALID; + } + if (IsDataOp(op_desc->GetType())) { + GELOGD("init data op %s", op_desc->GetName().c_str()); + data_op_list_.push_back(op_desc); + } + if (op_desc->GetType() == NETOUTPUT) { + GELOGD("init netouput op %s", op_desc->GetName().c_str()); + output_op_list_.push_back(op_desc); + } + } + return SUCCESS; +} + /// @ingroup ge /// @brief Data Op Initialize. /// @param [in] NodePtr: Data Op. /// @param [in/out] data_op_index: NetOutput addr size info. /// @return Status -Status DavinciModel::InitDataOp(const NodePtr &node, uint32_t &data_op_index) { +Status DavinciModel::InitDataOp(const NodePtr &node, uint32_t &data_op_index, map &data_by_index) { // op_desc Checked by Init: Data, valid. auto op_desc = node->GetOpDesc(); if (known_node_) { - data_op_list_.push_back(op_desc); return SUCCESS; } uint32_t parent_index = 0; // Ignore subgraph Data Node. @@ -885,6 +927,7 @@ Status DavinciModel::InitDataOp(const NodePtr &node, uint32_t &data_op_index) { return PARAM_INVALID; } new_input_data_info_[data_index] = zero_copy_offset; + data_by_index[data_index] = op_desc; for (size_t index = 0; index < virtual_addr_list.size(); ++index) { void *addr = virtual_addr_list.at(index); @@ -904,6 +947,24 @@ Status DavinciModel::InitDataOp(const NodePtr &node, uint32_t &data_op_index) { return SUCCESS; } +/// +/// @ingroup ge +/// @brief Sort Data op list by index. +/// @param [in] data_by_index: map of Data Op. 
+/// @return +/// +void DavinciModel::AdjustDataOpList(const map &data_by_index) { + if (data_by_index.size() != data_op_list_.size()) { + GELOGW("Data map size: %zu, Data list size: %zu.", data_by_index.size(), data_op_list_.size()); + return; + } + + data_op_list_.clear(); + for (auto &item : data_by_index) { + data_op_list_.emplace_back(item.second); + } +} + /// /// @ingroup ge /// @brief input zero copy node Initialize. @@ -946,7 +1007,6 @@ Status DavinciModel::InitNetOutput(const NodePtr &node) { auto op_desc = node->GetOpDesc(); // excludes the function op sub graph, e.g. case,if if (known_node_) { - output_op_list_.push_back(op_desc); return SUCCESS; } ComputeGraphPtr owner_graph = node->GetOwnerComputeGraph(); @@ -989,9 +1049,6 @@ Status DavinciModel::InitNetOutput(const NodePtr &node) { new_output_data_info_[num + idx] = zero_copy_offset; void *addr = virtual_addr_list.at(idx); int64_t input_offset = input_offset_list.at(idx); - if (new_output_outside_addrs_.find(addr) != new_output_outside_addrs_.end()) { - continue; - } vector tensor_addrs; zero_copy_offset.SetOutputOutsideAddrs(input_offset, fusion_flag, addr, tensor_addrs); auto rslt = new_output_outside_addrs_.insert(std::pair(addr, zero_copy_offset)); @@ -1464,6 +1521,17 @@ void DavinciModel::GetCombinedDynamicDims(std::vector> &bat batch_info = combined_batch_info_; } +/// +/// @ingroup ge +/// @brief Get user designate shape order +/// @param [out] user_input_shape_order +/// @return None +/// +void DavinciModel::GetUserDesignateShapeOrder(std::vector &user_input_shape_order) const { + user_input_shape_order.clear(); + user_input_shape_order = user_designate_shape_order_; +} + /// /// @ingroup ge /// @brief Get AIPP input info @@ -1475,7 +1543,7 @@ Status DavinciModel::GetAIPPInfo(uint32_t index, AippConfigInfo &aipp_info) { GE_CHK_BOOL_RET_STATUS(index < data_op_list_.size(), PARAM_INVALID, "Index %u is invalid.", index); OpDescPtr data_op = data_op_list_[index]; if 
(!data_op->HasAttr(ATTR_NAME_AIPP)) { - GELOGE(GE_AIPP_NOT_EXIST, "GetAIPPInfo: there is not AIPP related with index %u.", index); + GELOGW("GetAIPPInfo: there is not AIPP related with index %u.", index); return GE_AIPP_NOT_EXIST; } @@ -1488,10 +1556,6 @@ Status DavinciModel::GetAIPPInfo(uint32_t index, AippConfigInfo &aipp_info) { GE_CHK_STATUS_RET(OpUtils::ConvertAippParams(aipp_attr, aipp_params.get()), "get aipp params failed"); GELOGI("GetAIPPInfo: node data: %s, type: %s, current index: %u, current node related input rank: %u", data_op->GetName().c_str(), data_op->GetType().c_str(), index, aipp_params->related_input_rank()); - if (aipp_params->aipp_mode() == domi::AippOpParams::dynamic) { - GELOGI("GetAIPPInfo, dynamic Aipp is not support to query temporarily."); - return GE_DYNAMIC_AIPP_NOT_SUPPORT_QUERY; - } GE_CHK_STATUS_RET(AippUtils::ConvertAippParams2AippInfo(aipp_params.get(), aipp_info), "convert aipp params to aipp config info failed"); @@ -1563,51 +1627,51 @@ Status DavinciModel::GetInputOutputDescInfoForZeroCopy(vector &model_input_dims, Format &format, + InputOutputDescInfo &input) { uint32_t n, c, h, w; n = format == FORMAT_NHWC ? NHWC_DIM_N : NCHW_DIM_N; c = format == FORMAT_NHWC ? NHWC_DIM_C : NCHW_DIM_C; h = format == FORMAT_NHWC ? NHWC_DIM_H : NCHW_DIM_H; w = format == FORMAT_NHWC ? 
NHWC_DIM_W : NCHW_DIM_W; + if (model_input_dims.size() == static_cast(NORMAL_TENSOR_SIZE)) { + input.shape_info.num = model_input_dims[n]; + input.shape_info.height = model_input_dims[h]; + input.shape_info.width = model_input_dims[w]; + input.shape_info.channel = model_input_dims[c]; + } + for (size_t k = 0; k < model_input_dims.size(); ++k) { + input.shape_info.dims.push_back(model_input_dims[k]); + } + return; +} + +void DavinciModel::CreateInputDimsInfo(const OpDescPtr &op_desc, Format format, InputOutputDescInfo &input) { if (is_new_model_desc_ && op_desc->HasAttr(ATTR_NAME_INPUT_DIMS)) { // When static aipp is set, need to get the model input dims which processed by aipp vector model_input_dims; (void)AttrUtils::GetListInt(op_desc, ATTR_NAME_INPUT_DIMS, model_input_dims); - if (model_input_dims.size() == static_cast(NORMAL_TENSOR_SIZE)) { - input.shape_info.num = model_input_dims[n]; - input.shape_info.height = model_input_dims[h]; - input.shape_info.width = model_input_dims[w]; - input.shape_info.channel = model_input_dims[c]; - } - for (size_t k = 0; k < model_input_dims.size(); ++k) { - input.shape_info.dims.push_back(model_input_dims[k]); - } - is_new_model_desc_ = false; + SetInputDimsInfo(model_input_dims, format, input); return; } - - if (!op_desc->HasAttr(ATTR_MBATCH_ORIGIN_INPUT_DIMS)) { - if (op_desc->GetInputDescPtr(0)->GetShape().GetDimNum() == static_cast(NORMAL_TENSOR_SIZE)) { - input.shape_info.num = op_desc->GetInputDescPtr(0)->GetShape().GetDim(n); - input.shape_info.height = op_desc->GetInputDescPtr(0)->GetShape().GetDim(h); - input.shape_info.width = op_desc->GetInputDescPtr(0)->GetShape().GetDim(w); - input.shape_info.channel = op_desc->GetInputDescPtr(0)->GetShape().GetDim(c); - } - for (size_t k = 0; k < op_desc->GetInputDescPtr(0)->GetShape().GetDimNum(); k++) { - input.shape_info.dims.push_back(op_desc->GetInputDescPtr(0)->GetShape().GetDim(k)); - } + // judge if this data is linked dynamic aipp first, multiply batch has been 
considered + if (op_desc->HasAttr("_dynamic_aipp_input_dims")) { + vector dynamic_aipp_input_dims; + (void)AttrUtils::GetListInt(op_desc, "_dynamic_aipp_input_dims", dynamic_aipp_input_dims); + SetInputDimsInfo(dynamic_aipp_input_dims, format, input); + return; } else { - vector origin_input_dims; - (void)AttrUtils::GetListInt(op_desc, ATTR_MBATCH_ORIGIN_INPUT_DIMS, origin_input_dims); - if (origin_input_dims.size() == static_cast(NORMAL_TENSOR_SIZE)) { - input.shape_info.num = origin_input_dims[n]; - input.shape_info.height = origin_input_dims[h]; - input.shape_info.width = origin_input_dims[w]; - input.shape_info.channel = origin_input_dims[c]; - } - for (size_t k = 0; k < origin_input_dims.size(); ++k) { - input.shape_info.dims.push_back(origin_input_dims[k]); + // judge if this data is multiply batch + if (!op_desc->HasAttr(ATTR_MBATCH_ORIGIN_INPUT_DIMS)) { + vector input_dims = op_desc->GetInputDescPtr(0)->GetShape().GetDims(); + SetInputDimsInfo(input_dims, format, input); + return; + } else { + vector origin_input_dims; + (void)AttrUtils::GetListInt(op_desc, ATTR_MBATCH_ORIGIN_INPUT_DIMS, origin_input_dims); + SetInputDimsInfo(origin_input_dims, format, input); + return; } } } @@ -1630,6 +1694,8 @@ Status DavinciModel::GetInputDescInfo(vector &input_desc, s formats.push_back(format); input_desc.push_back(input); } + // cause GetInputDescInfo called not only once, set is_new_model_desc_ to false after calc the model input dims + is_new_model_desc_ = false; return SUCCESS; } @@ -2106,22 +2172,24 @@ Status DavinciModel::CopyOutputData(uint32_t data_id, OutputData &output_data, r return FAILED; } + if ((kind == RT_MEMCPY_DEVICE_TO_DEVICE) && (copy_only_addrs_.count(output.second.GetBasicAddr()) == 0)) { + continue; // Skip: Feed by zero copy. + } + DataBuffer &buffer = blobs[output.first]; uint64_t mem_size = static_cast(output.second.GetDataSize()); if ((buffer.length == 0) || (mem_size == 0)) { GELOGI("Length of data is zero, No need copy. 
output tensor index=%u", output.first); continue; } - if (buffer.length < mem_size) { + if (is_dynamic_) { + GELOGI("No need to check output data size."); + } else if (buffer.length < mem_size) { GELOGE(FAILED, "Tensor data size=%lu, buffer size=%u", mem_size, buffer.length); return FAILED; } else if (buffer.length > mem_size) { GELOGW("Tensor data size=%lu, buffer size=%u", mem_size, buffer.length); } - - if ((kind == RT_MEMCPY_DEVICE_TO_DEVICE) && (copy_only_addrs_.count(output.second.GetBasicAddr()) == 0)) { - continue; // Skip: Feed by zero copy. - } uint64_t data_size = output.second.GetDataSize(); uint64_t buffer_length = buffer.length; void *buffer_addr = reinterpret_cast(reinterpret_cast(buffer.data)); @@ -2564,10 +2632,12 @@ Status DavinciModel::ModelRunStop() { void DavinciModel::UnbindTaskSinkStream() { // unbinding hcom stream UnbindHcomStream(); - for (size_t i = 0; i < stream_list_.size(); i++) { - // unbind rt_model_handle and streams - GE_LOGW_IF(rtModelUnbindStream(rt_model_handle_, stream_list_[i]) != RT_ERROR_NONE, - "Unbind stream from model failed! Index: %zu", i); + if (is_stream_list_bind_) { + for (size_t i = 0; i < stream_list_.size(); i++) { + // unbind rt_model_handle and streams + GE_LOGW_IF(rtModelUnbindStream(rt_model_handle_, stream_list_[i]) != RT_ERROR_NONE, + "Unbind stream from model failed! 
Index: %zu", i); + } } if (is_inner_model_stream_) { @@ -2610,11 +2680,7 @@ Status DavinciModel::CreateKnownZeroCopyMap(const vector &inputs, const return SUCCESS; } const vector addr_list = ModelUtils::GetInputDataAddrs(runtime_param_, output_op_list_[kDataIndex]); - if (outputs.size() > addr_list.size()) { - GELOGE(FAILED, "output data addr %u should less than output op number %u.", outputs.size(), addr_list.size()); - return FAILED; - } - for (size_t i = 0; i < addr_list.size(); ++i) { + for (size_t i = 0; i < addr_list.size() && i < outputs.size(); ++i) { knonw_output_data_info_[addr_list[i]] = outputs[i]; GELOGI("DavinciModel::CreateKnownZeroCopyMap output %d,v addr %p,p addr %p .", i, addr_list[i], outputs[i]); } @@ -2755,19 +2821,21 @@ Status DavinciModel::DistributeTask() { } const auto &model_task_def = ge_model_->GetModelTaskDefPtr(); + GELOGI("there are %zu task need to save.", task_list_.size()); for (size_t task_index = 0; task_index < task_list_.size(); ++task_index) { auto &task = task_list_.at(task_index); GE_CHK_STATUS_RET(task->Distribute(), "Task[%zu] distribute fail", task_index); // for data dump - if (reinterpret_cast(task->GetDumpArgs()) != nullptr) { - auto op_index = std::max(model_task_def->task(task_index).kernel().context().op_index(), - model_task_def->task(task_index).kernel_ex().op_index()); - OpDescPtr op = GetOpByIndex(op_index); - if (op == nullptr) { - GELOGE(PARAM_INVALID, "Op index %u is null, op list size %zu.", op_index, op_list_.size()); - return PARAM_INVALID; - } + auto op_index = std::max(model_task_def->task(task_index).kernel().context().op_index(), + model_task_def->task(task_index).kernel_ex().op_index()); + OpDescPtr op = GetOpByIndex(op_index); + if (op == nullptr) { + GELOGE(PARAM_INVALID, "Op index %u is null, op list size %zu.", op_index, op_list_.size()); + return PARAM_INVALID; + } + SaveDumpOpInfo(runtime_param_, op, task->GetTaskID(), task->GetStreamId()); + if (reinterpret_cast(task->GetDumpArgs()) != 
nullptr) { bool call_dump = GetDumpProperties().IsLayerNeedDump(name_, om_name_, op->GetName()) && task->CallSaveDumpInfo(); if (call_dump) { SaveDumpTask(task->GetTaskID(), task->GetStreamId(), op, task->GetDumpArgs()); @@ -2873,7 +2941,7 @@ void DavinciModel::DisableZeroCopy(const void *addr) { void DavinciModel::SetZeroCopyAddr(const OpDescPtr &op_desc, const std::vector &outside_addrs, const void *info, void *args, size_t size, size_t offset) { // Internal call has ensured that op_desc is not nullptr - GELOGI("[ZCPY] SetZeroCopyAddr for %s.", op_desc->GetName().c_str()); + GELOGD("[ZCPY] SetZeroCopyAddr for %s.", op_desc->GetName().c_str()); size_t nums = outside_addrs.size(); ZeroCopyTask zero_copy_task(op_desc->GetName(), static_cast(args), size); for (size_t i = 0; i < nums; ++i) { @@ -2994,7 +3062,7 @@ Status DavinciModel::CopyModelData(const InputData &input_data, OutputData &outp } for (ZeroCopyTask &task : zero_copy_tasks_) { - GE_CHK_STATUS_RET(task.DistributeParam(is_async_mode_ ? rt_model_stream_ : nullptr), "[ZCPY] Update args failed."); + GE_CHK_STATUS_RET(task.DistributeParam(is_async_mode_, rt_model_stream_), "[ZCPY] Update args failed."); } output_data.index = input_data.index; @@ -3106,7 +3174,6 @@ const char *DavinciModel::GetRegisterStub(const string &binfile, const string &s } else { binfile_key = session_graph_id + "_" + binfile; } - std::lock_guard lock(tvm_bin_mutex_); auto it = tvm_bin_kernel_.find(binfile_key); if (it != tvm_bin_kernel_.end()) { return it->c_str(); @@ -3242,7 +3309,6 @@ void DavinciModel::StoreTbeHandle(const std::string &handle_key) { // Online mode FE may call rtFunctionRegister. TBEHandleStore &kernel_store = TBEHandleStore::GetInstance(); - // Need protection of tvm_bin_mutex_. auto it = used_tbe_handle_map_.find(handle_key); if (it != used_tbe_handle_map_.end()) { // GE registered, increase reference. 
@@ -3262,9 +3328,9 @@ void DavinciModel::StoreTbeHandle(const std::string &handle_key) { void DavinciModel::CleanTbeHandle() { TBEHandleStore &kernel_store = TBEHandleStore::GetInstance(); - std::lock_guard lock(tvm_bin_mutex_); kernel_store.EraseTBEHandle(used_tbe_handle_map_); used_tbe_handle_map_.clear(); + tvm_bin_kernel_.clear(); } /// @@ -3315,21 +3381,26 @@ Status DavinciModel::InitStreamSwitchN(const OpDescPtr &op_desc) { GELOGI("StreamSwitchNOp node:%s, active_stream_id=%u.", op_desc->GetName().c_str(), active_stream_list[j]); } - (void)AttrUtils::GetInt(op_desc, ATTR_DYNAMIC_TYPE, dynamic_type_); - - batch_info_.clear(); - combined_batch_info_.clear(); uint32_t batch_num = 0; if (!AttrUtils::GetInt(op_desc, ATTR_NAME_BATCH_NUM, batch_num)) { GELOGE(FAILED, "Failed to get attr ATTR_NAME_BATCH_NUM, StreamSwitchN: %s.", op_desc->GetName().c_str()); return FAILED; } - for (uint32_t i = 0; i < batch_num; i++) { + return SetDynamicBatchInfo(op_desc, batch_num); +} + +Status DavinciModel::SetDynamicBatchInfo(const OpDescPtr &op_desc, uint32_t batch_num) { + batch_info_.clear(); + combined_batch_info_.clear(); + + (void)AttrUtils::GetInt(op_desc, ATTR_DYNAMIC_TYPE, dynamic_type_); + (void)AttrUtils::GetListStr(op_desc, ATTR_USER_DESIGNEATE_SHAPE_ORDER, user_designate_shape_order_); + for (uint32_t i = 0; i < batch_num; ++i) { std::vector batch_shape; const std::string attr_name = ATTR_NAME_PRED_VALUE + "_" + std::to_string(i); if (!AttrUtils::GetListInt(op_desc, attr_name, batch_shape)) { - GELOGE(FAILED, "Failed to get attr ATTR_NAME_PRED_VALUE, StreamSwitchN: %s.", op_desc->GetName().c_str()); + GELOGE(FAILED, "Get attr ATTR_NAME_PRED_VALUE failed, Node: %s", op_desc->GetName().c_str()); batch_info_.clear(); return FAILED; } @@ -3344,6 +3415,16 @@ Status DavinciModel::InitStreamSwitchN(const OpDescPtr &op_desc) { return SUCCESS; } +Status DavinciModel::InitCase(const OpDescPtr &op_desc) { + uint32_t batch_num = 0; + if (!AttrUtils::GetInt(op_desc, 
ATTR_NAME_BATCH_NUM, batch_num)) { + GELOGI("Not multi-batch Node: %s", op_desc->GetName().c_str()); + return SUCCESS; + } + + return SetDynamicBatchInfo(op_desc, batch_num); +} + bool DavinciModel::IsBroadCastOpData(const ge::NodePtr &var_node) { for (auto out_anchor : var_node->GetAllOutDataAnchors()) { GE_RT_FALSE_CHECK_NOTNULL(out_anchor); @@ -3406,12 +3487,13 @@ Status DavinciModel::NnExecute(rtStream_t stream, bool async_mode, const InputDa GELOGI("Model Run begin, model id:%u, data index:%u, flag:%d.", model_id_, input_data.index, is_async_mode_); GE_CHK_STATUS_RET(InitModelStream(stream), "Init model stream failed."); - if (!input_data.is_dynamic_batch) { + is_dynamic_ = input_data.is_dynamic_batch; + if (!is_dynamic_) { zero_copy_batch_label_addrs_.clear(); } GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), SetProfileTime(MODEL_PRE_PROC_START)); - Status ret = CopyModelData(input_data, output_data, input_data.is_dynamic_batch); + Status ret = CopyModelData(input_data, output_data, is_dynamic_); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ret, "Copy input data to model failed. 
model id: %u", model_id_); @@ -3587,6 +3669,7 @@ void DavinciModel::SetDataDumperArgs(const ComputeGraphPtr &compute_graph) { data_dumper_.SetOmName(om_name_); data_dumper_.SetComputeGraph(compute_graph); data_dumper_.SetRefInfo(saved_task_addrs_); + data_dumper_.SetL1FusionAddr(l1_fusion_addr_); int32_t device_id = 0; rtError_t rt_ret = rtGetDevice(&device_id); @@ -3627,19 +3710,9 @@ void DavinciModel::PushHcclStream(rtStream_t value) { all_hccl_stream_list_.push_back(value); } -void DavinciModel::CreateHcclFollowStream(rtStream_t stream, int64_t remain_cap) { +void DavinciModel::SaveHcclFollowStream(int64_t main_stream_id, rtStream_t stream) { std::lock_guard lock(capacity_of_stream_mutex_); - capacity_of_stream_.emplace_back(make_pair(stream, remain_cap)); -} - -void DavinciModel::ReuseHcclFollowStream(int64_t remain_cap, int64_t &index) { - std::lock_guard lock(capacity_of_stream_mutex_); - if (remain_cap == 0) { - capacity_of_stream_.erase(capacity_of_stream_.begin() + index); - } else { - capacity_of_stream_.at(index).second = remain_cap; - index++; - } + main_follow_stream_mapping_[main_stream_id].emplace_back(stream); } Status DavinciModel::GetComputeGraphInfo(const ComputeGraphPtr &graph, vector &graph_desc_info) { @@ -3756,8 +3829,7 @@ Status DavinciModel::GetAllAippInputOutputDims(uint32_t index, std::vectorGetInputDescPtr(kDataIndex)), data_input_size); GELOGD( "GetAllAippInputOutputDims related Data[%d]: tensor_name is %s, dim_num is %u, tensor_size: %zu, format: " - "%s, " - "data_type: %s, shape: %s .", + "%s, data_type: %s, shape: %s .", index, data_op->GetName().c_str(), data_input_desc->GetShape().GetDimNum(), data_input_size, TypeUtils::FormatToSerialString(data_input_desc->GetFormat()).c_str(), TypeUtils::DataTypeToSerialString(data_input_desc->GetDataType()).c_str(), diff --git a/src/ge/graph/load/new_model_manager/davinci_model.h b/src/ge/graph/load/new_model_manager/davinci_model.h index cb7e4528..e77c5510 100644 --- 
a/src/ge/graph/load/new_model_manager/davinci_model.h +++ b/src/ge/graph/load/new_model_manager/davinci_model.h @@ -184,10 +184,10 @@ class DavinciModel { size_t TotalMemSize() const { return runtime_param_.mem_size; } // model name - string Name() { return name_; } + string Name() const { return name_; } // om_name - string OmName() { return om_name_; } + string OmName() const { return om_name_; } // version uint32_t Version() const { return version_; } @@ -268,7 +268,7 @@ class DavinciModel { /// @brief For TVM Op, avoid Addr Reuse. /// @return void* /// - static const char *GetRegisterStub(const string &tvm_binfile_key, const string &session_graph_model_id = ""); + const char *GetRegisterStub(const string &tvm_binfile_key, const string &session_graph_model_id = ""); /// /// @ingroup ge @@ -299,6 +299,8 @@ class DavinciModel { /// void GetCombinedDynamicDims(std::vector> &batch_info) const; + void GetUserDesignateShapeOrder(std::vector &user_input_shape_order) const; + void GetCurShape(std::vector &batch_info, int32_t &dynamic_type); void GetModelAttr(std::vector &dynamic_output_shape_info); @@ -440,6 +442,10 @@ class DavinciModel { Status SinkTimeProfile(const InputData ¤t_data); + void SaveDumpOpInfo(const RuntimeParam &model_param, const OpDescPtr &op, uint32_t task_id, uint32_t stream_id) { + data_dumper_.SaveDumpOpInfo(model_param, op, task_id, stream_id); + } + void SaveDumpTask(uint32_t task_id, uint32_t stream_id, const std::shared_ptr &op_desc, uintptr_t args) { data_dumper_.SaveDumpTask(task_id, stream_id, op_desc, args); } @@ -449,9 +455,8 @@ class DavinciModel { DavinciModel(const DavinciModel &model) = delete; - const vector> &GetHcclFolowStream() { return capacity_of_stream_; } - void CreateHcclFollowStream(rtStream_t stream, int64_t remain_cap); - void ReuseHcclFollowStream(int64_t remain_cap, int64_t &index); + const map> &GetHcclFolowStream() { return main_follow_stream_mapping_; } + void SaveHcclFollowStream(int64_t main_stream_id, rtStream_t 
stream); void InitRuntimeParams(); Status InitVariableMem(); @@ -500,6 +505,16 @@ class DavinciModel { void SetDumpProperties(const DumpProperties &dump_properties) { data_dumper_.SetDumpProperties(dump_properties); } const DumpProperties &GetDumpProperties() const { return data_dumper_.GetDumpProperties(); } + void SetMemcpyOffsetAndAddr(map &memcpy_4g_offset_addr) { + memcpy_4g_offset_addr_.insert(memcpy_4g_offset_addr.begin(), memcpy_4g_offset_addr.end()); + } + const map &GetMemcpyOffsetAndAddr() const { return memcpy_4g_offset_addr_; } + + bool GetOpDescInfo(uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info) const { + return data_dumper_.GetOpDescInfo(stream_id, task_id, op_desc_info); + } + Status InitInputOutputForDynamic(const ComputeGraphPtr &compute_graph); + private: // memory address of weights uint8_t *weights_mem_base_; @@ -575,6 +590,8 @@ class DavinciModel { void CreateInputDimsInfo(const OpDescPtr &op_desc, Format format, InputOutputDescInfo &input); + void SetInputDimsInfo(const vector &model_input_dims, Format &format, InputOutputDescInfo &input); + Status GetInputDescInfo(vector &input_desc, std::vector &formats); Status InitTaskInfo(domi::ModelTaskDef &modelTaskInfo); @@ -619,7 +636,15 @@ class DavinciModel { /// @param [in/out] data_op_index: NetOutput addr size info. /// @return Status /// - Status InitDataOp(const NodePtr &node, uint32_t &data_op_index); + Status InitDataOp(const NodePtr &node, uint32_t &data_op_index, map &data_by_index); + + /// + /// @ingroup ge + /// @brief Sort Data op list by index. + /// @param [in] data_by_index: map of Data Op. + /// @return + /// + void AdjustDataOpList(const map &data_by_index); /// /// @ingroup ge @@ -666,6 +691,15 @@ class DavinciModel { Status InitStreamSwitchN(const OpDescPtr &op_desc); + /// + /// @ingroup ge + /// @brief Case Op Init. 
+ /// @return Status + /// + Status InitCase(const OpDescPtr &op_desc); + + Status SetDynamicBatchInfo(const OpDescPtr &op_desc, uint32_t batch_num); + /// /// @ingroup ge /// @brief TVM Op Init. @@ -840,7 +874,7 @@ class DavinciModel { // for reuse hccl_follow_stream std::mutex capacity_of_stream_mutex_; - std::vector> capacity_of_stream_; + std::map> main_follow_stream_mapping_; vector event_list_; @@ -866,6 +900,7 @@ class DavinciModel { bool is_async_mode_; // For NN execute, Async mode use rtMemcpyAsync on rt_model_stream_. + bool is_stream_list_bind_{false}; bool is_pure_head_stream_{false}; rtStream_t rt_head_stream_{nullptr}; rtStream_t rt_entry_stream_{nullptr}; @@ -891,8 +926,8 @@ class DavinciModel { std::set hcom_streams_; RuntimeParam runtime_param_; - static std::mutex tvm_bin_mutex_; // lock for tvm maps. - static std::set tvm_bin_kernel_; + static std::mutex tvm_bin_mutex_; + std::set tvm_bin_kernel_; std::map used_tbe_handle_map_; @@ -906,6 +941,7 @@ class DavinciModel { uint64_t iterator_count_; bool is_l1_fusion_enable_; std::map saved_task_addrs_; + void *l1_fusion_addr_ = nullptr; bool known_node_ = false; uint32_t total_args_size_ = 0; @@ -921,7 +957,9 @@ class DavinciModel { vector> batch_info_; std::vector> combined_batch_info_; + vector user_designate_shape_order_; int32_t dynamic_type_ = 0; + bool is_dynamic_ = false; vector batch_size_; // key: input tensor name, generally rts op; @@ -938,6 +976,8 @@ class DavinciModel { void *op_debug_addr_ = nullptr; void *p2p_debug_addr_ = nullptr; bool is_new_model_desc_{false}; + + std::map memcpy_4g_offset_addr_; }; } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_DAVINCI_MODEL_H_ diff --git a/src/ge/graph/load/new_model_manager/model_manager.cc b/src/ge/graph/load/new_model_manager/model_manager.cc index 51b5b028..33e39847 100644 --- a/src/ge/graph/load/new_model_manager/model_manager.cc +++ b/src/ge/graph/load/new_model_manager/model_manager.cc @@ -20,6 +20,7 @@ #include 
"common/l2_cache_optimize.h" #include "common/profiling/profiling_manager.h" +#include "common/dump/dump_manager.h" #include "common/properties_manager.h" #include "framework/common/debug/ge_log.h" #include "framework/common/util.h" @@ -172,7 +173,7 @@ ge::Status ModelManager::DestroyAicpuSessionForInfer(uint32_t model_id) { return GE_EXEC_MODEL_ID_INVALID; } uint64_t session_id = it->second->GetSessionId(); - GELOGI("Destroy aicpu session for infer, session id is %u.", session_id); + GELOGI("Destroy aicpu session for infer, session id is %lu.", session_id); DestroyAicpuSession(session_id); return SUCCESS; } @@ -259,7 +260,7 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptrCheckIsUnknownShape(is_shape_unknown), "CheckIsUnknownShape failed, model id:%u", model_id); - if (is_shape_unknown) { + if (is_shape_unknown || GetContext().GetHostExecFlag()) { return DoLoadHybridModelOnline(model_id, ge_root_model, listener); } @@ -729,6 +730,22 @@ Status ModelManager::GetCombinedDynamicDims(const uint32_t model_id, vector &user_input_shape_order) { + auto davinci_model = GetModel(model_id); + GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, + "GetUserDesignateShapeOrder Failed, Invalid Model ID %u!", model_id) + davinci_model->GetUserDesignateShapeOrder(user_input_shape_order); + return SUCCESS; +} + Status ModelManager::GetCurShape(const uint32_t model_id, std::vector &batch_info, int32_t &dynamic_type) { std::shared_ptr davinci_model = GetModel(model_id); GE_CHECK_NOTNULL(davinci_model); @@ -831,7 +848,11 @@ Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model } davinci_model->SetDeviceId(device_id); davinci_model->SetOmName(model.om_name); - davinci_model->SetDumpProperties(dump_properties_); + if (DumpManager::GetInstance().IsDumpOpen()) { + davinci_model->SetDumpProperties(DumpManager::GetInstance().GetDumpProperties()); + } else { + davinci_model->SetDumpProperties(dump_properties_); + } /// In 
multi-threaded inference, using the same session_id among multiple threads may cause some threads to fail. /// These session_ids come from the same model, so the values of session_id are the same. @@ -1070,4 +1091,19 @@ ge::Status ModelManager::SyncExecuteModel(uint32_t model_id, const vectorExecute(inputs, outputs); } + +Status ModelManager::GetOpDescInfo(uint32_t device_id, uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info) { + for (const auto &model : model_map_) { + auto davinci_model = model.second; + if (davinci_model->GetDeviceId() == device_id) { + GELOGI("Start to GetOpDescInfo of device_id: %u.", device_id); + if (davinci_model->GetOpDescInfo(stream_id, task_id, op_desc_info)) { + GELOGI("Find specific node of stream_id: %u, task_id: %u.", stream_id, task_id); + return SUCCESS; + } + } + } + return FAILED; +} + } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/model_manager.h b/src/ge/graph/load/new_model_manager/model_manager.h index 153d324d..a25b56a8 100644 --- a/src/ge/graph/load/new_model_manager/model_manager.h +++ b/src/ge/graph/load/new_model_manager/model_manager.h @@ -201,6 +201,15 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { /// ge::Status GetCombinedDynamicDims(const uint32_t model_id, std::vector> &batch_info); + /// + /// @ingroup ge + /// @brief Get user designate shape order + /// @param [in] model_id + /// @param [out] user_input_shape_order + /// @return execute result + /// + Status GetUserDesignateShapeOrder(const uint32_t model_id, std::vector &user_input_shape_order); + /// /// @ingroup ge /// @brief Get AIPP info @@ -263,6 +272,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { std::vector &output_dims); bool IsDynamicShape(uint32_t model_id); + ge::Status GetOpDescInfo(uint32_t device_id, uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info); private: /// diff --git a/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc 
b/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc index cb8cfed6..98d1d5a4 100644 --- a/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc @@ -28,7 +28,6 @@ namespace { const uint32_t kMaxTaskOfStream = 200; } -uint32_t HcclTaskInfo::max_node_of_hccl_stream_ = 0; std::mutex HcclTaskInfo::hccl_follow_stream_mutex_; HcclTaskInfo::~HcclTaskInfo() { @@ -41,7 +40,6 @@ HcclTaskInfo::~HcclTaskInfo() { } davinci_model_ = nullptr; ops_kernel_store_ = nullptr; - max_node_of_hccl_stream_ = 0; args_ = nullptr; } Status HcclTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { @@ -133,45 +131,39 @@ Status HcclTaskInfo::SetFollowStream(const ge::ConstOpDescPtr &op_desc, DavinciM } std::lock_guard lock(hccl_follow_stream_mutex_); - if (max_node_of_hccl_stream_ == 0) { - uint32_t max_stream_count; - uint32_t max_task_count; - ret = rtGetMaxStreamAndTask(RT_NORMAL_STREAM, &max_stream_count, &max_task_count); - if (ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Get max stream and task count by rts failed."); - return RT_ERROR_TO_GE_STATUS(ret); - } - max_node_of_hccl_stream_ = max_task_count / kMaxTaskOfStream; - } + int64_t main_stream_id = op_desc->GetStreamId(); + const std::map> &main_follow_stream_mapping = davinci_model->GetHcclFolowStream(); - if (static_cast(hccl_stream_num) <= davinci_model->GetHcclFolowStream().size()) { - GELOGI("capacity of follow stream is enough to be reused."); - ReuseStream(hccl_stream_num, davinci_model); + if (main_follow_stream_mapping.find(main_stream_id) != main_follow_stream_mapping.end()) { + const std::vector &follow_stream_usage = main_follow_stream_mapping.at(main_stream_id); + if (static_cast(hccl_stream_num) <= follow_stream_usage.size()) { + GELOGI("capacity of follow stream is enough to be reused."); + for (int64_t i = 0; i < hccl_stream_num; i++) { + hccl_stream_list_.emplace_back(follow_stream_usage.at(i)); + } + } 
else { + GELOGI("need to reuse follow stream and create new follow stream."); + size_t created_stream_num = follow_stream_usage.size(); + hccl_stream_list_ = follow_stream_usage; + ret = CreateStream(hccl_stream_num - created_stream_num, davinci_model, main_stream_id); + if (ret != SUCCESS) { + GELOGE(RT_FAILED, "Create hccl stream failed."); + return RT_ERROR_TO_GE_STATUS(ret); + } + } + GELOGI("Initialize hccl slave stream success, hcclStreamNum =%ld", hccl_stream_num); } else { - GELOGI("need to reuse follow stream and create new follow stream."); - size_t created_stream_num = davinci_model->GetHcclFolowStream().size(); - ReuseStream(created_stream_num, davinci_model); - ret = CreateStream(hccl_stream_num - created_stream_num, davinci_model); + GELOGI("need to create follow stream for %s with new mainstream %ld.", op_desc->GetName().c_str(), main_stream_id); + ret = CreateStream(hccl_stream_num, davinci_model, main_stream_id); if (ret != SUCCESS) { GELOGE(RT_FAILED, "Create hccl stream failed."); return RT_ERROR_TO_GE_STATUS(ret); } } - GELOGI("Initialize hccl slave stream success, hcclStreamNum =%ld", hccl_stream_num); return SUCCESS; } -void HcclTaskInfo::ReuseStream(int64_t stream_num, DavinciModel *davinci_model) { - GELOGI("Start to reuse %ld follow stream.", stream_num); - int64_t index = 0; - for (int64_t i = 0; i < stream_num; i++) { - hccl_stream_list_.emplace_back(davinci_model->GetHcclFolowStream().at(index).first); - int64_t remain_cap = davinci_model->GetHcclFolowStream().at(index).second - 1; - davinci_model->ReuseHcclFollowStream(remain_cap, index); - } -} - -Status HcclTaskInfo::CreateStream(int64_t stream_num, DavinciModel *davinci_model) { +Status HcclTaskInfo::CreateStream(int64_t stream_num, DavinciModel *davinci_model, int64_t main_stream_id) { GELOGI("Start to create %ld hccl stream.", stream_num); for (int64_t i = 0; i < stream_num; ++i) { rtStream_t stream = nullptr; @@ -189,8 +181,7 @@ Status HcclTaskInfo::CreateStream(int64_t 
stream_num, DavinciModel *davinci_mode return RT_ERROR_TO_GE_STATUS(rt_ret); } GELOGD("hccl_stream addr is=%p", stream); - int64_t remain_cap = max_node_of_hccl_stream_ - 1; - davinci_model->CreateHcclFollowStream(stream, remain_cap); + davinci_model->SaveHcclFollowStream(main_stream_id, stream); hccl_stream_list_.emplace_back(stream); davinci_model->PushHcclStream(stream); diff --git a/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.h b/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.h index cc3109f4..d8456834 100644 --- a/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.h @@ -60,9 +60,7 @@ class HcclTaskInfo : public TaskInfo { void GetPrivateDefByTaskDef(const domi::TaskDef &task); - void ReuseStream(int64_t stream_num, DavinciModel *davinci_model); - - ge::Status CreateStream(int64_t stream_num, DavinciModel *davinci_model); + ge::Status CreateStream(int64_t stream_num, DavinciModel *davinci_model, int64_t main_stream_id); Status SetFollowStream(const ge::ConstOpDescPtr &op_desc, DavinciModel *davinci_model); @@ -77,7 +75,6 @@ class HcclTaskInfo : public TaskInfo { void *private_def_; uint32_t private_def_len_; static std::mutex hccl_follow_stream_mutex_; - static uint32_t max_node_of_hccl_stream_; vector kernel_hccl_infos_; vector input_data_addrs_; vector output_data_addrs_; diff --git a/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc index da6d05ca..7c873c68 100644 --- a/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc @@ -25,7 +25,6 @@ #include "framework/common/debug/ge_log.h" #include "framework/common/l2_cache_optimize.h" #include "graph/debug/ge_attr_define.h" -#include "graph/debug/ge_attr_define.h" #include "graph/load/new_model_manager/davinci_model.h" #include 
"graph/load/new_model_manager/model_utils.h" #include "runtime/kernel.h" @@ -92,7 +91,7 @@ Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci string session_graph_model_id; davinci_model_->GetUniqueId(op_desc_, session_graph_model_id); // get bin_file_key - const char *bin_file_key = DavinciModel::GetRegisterStub(op_desc_->GetName(), session_graph_model_id); + const char *bin_file_key = davinci_model_->GetRegisterStub(op_desc_->GetName(), session_graph_model_id); // new aicpu kernel(rtCpuKernelLaunch) no need to check function if (kernel_type_ == cce::ccKernelType::CCE_AI_CORE) { rtError_t rt_ret; @@ -395,7 +394,14 @@ Status KernelTaskInfo::Distribute() { "stubfunc:%p blockdim:%u stream:%p", call_skt, task_id_, skt_id_, skt_info_.last_task_id, stub_func_name_.c_str(), stub_func_, block_dim_, stream_); // l1 fusion enable and env flag open (kCloseSkt for skt debug) - if (call_skt && (env_flag != kCloseSkt)) { + bool open_dump = false; + auto all_dump_model = davinci_model_->GetDumpProperties().GetAllDumpModel(); + if (all_dump_model.find(ge::DUMP_ALL_MODEL) != all_dump_model.end() || + all_dump_model.find(davinci_model_->Name()) != all_dump_model.end() || + all_dump_model.find(davinci_model_->OmName()) != all_dump_model.end()) { + open_dump = true; + } + if (call_skt && (env_flag != kCloseSkt) && !open_dump) { GE_RETURN_IF_ERROR(SuperKernelDistribute()); } else { // call rtKernelLaunch for current task @@ -577,7 +583,7 @@ Status KernelTaskInfo::InitTVMTask(uint16_t offset, const domi::KernelDef &kerne // When inferencing, stub_func_ is different from dynamic-registration to runtime, and needs to be modified. 
string session_graph_model_id; davinci_model_->GetUniqueId(op_desc, session_graph_model_id); - const char *bin_file_key = DavinciModel::GetRegisterStub(op_desc->GetName(), session_graph_model_id); + const char *bin_file_key = davinci_model_->GetRegisterStub(op_desc->GetName(), session_graph_model_id); rtError_t rt_ret = rtQueryFunctionRegistered(const_cast(bin_file_key)); if (rt_ret != RT_ERROR_NONE) { stub_func_ = const_cast(bin_file_key); @@ -634,7 +640,11 @@ Status KernelTaskInfo::InitTVMTask(uint16_t offset, const domi::KernelDef &kerne skt_dump_args_ = static_cast(args_) + offset; if (davinci_model_->GetDumpProperties().IsLayerNeedDump(davinci_model_->Name(), davinci_model_->OmName(), op_desc->GetName())) { - dump_flag_ = RT_KERNEL_DUMPFLAG; + if (IsL1FusionOp(op_desc)) { + dump_flag_ = RT_FUSION_KERNEL_DUMPFLAG; + } else { + dump_flag_ = RT_KERNEL_DUMPFLAG; + } dump_args_ = static_cast(args_) + offset; } @@ -653,6 +663,25 @@ Status KernelTaskInfo::InitTVMTask(uint16_t offset, const domi::KernelDef &kerne return SUCCESS; } +bool KernelTaskInfo::IsL1FusionOp(const OpDescPtr &op_desc) { + std::vector input_memory_type; + (void)ge::AttrUtils::GetListInt(op_desc, ATTR_NAME_INPUT_MEM_TYPE_LIST, input_memory_type); + for (size_t i = 0; i < input_memory_type.size(); ++i) { + if (input_memory_type.at(i) == RT_MEMORY_L1) { + return true; + } + } + + std::vector output_memory_type; + (void)ge::AttrUtils::GetListInt(op_desc, ATTR_NAME_OUTPUT_MEM_TYPE_LIST, output_memory_type); + for (size_t i = 0; i < output_memory_type.size(); ++i) { + if (output_memory_type.at(i) == RT_MEMORY_L1) { + return true; + } + } + return false; +} + Status KernelTaskInfo::InitAICPUCustomTask(uint32_t op_index, const domi::KernelDef &kernel_def) { GELOGI("Do InitAICPUCustomTask"); OpDescPtr op_desc = davinci_model_->GetOpByIndex(op_index); @@ -904,7 +933,11 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k if 
(davinci_model_->GetDumpProperties().IsLayerNeedDump(davinci_model_->Name(), davinci_model_->OmName(), op_desc->GetName())) { - dump_flag_ = RT_KERNEL_DUMPFLAG; + if (IsL1FusionOp(op_desc)) { + dump_flag_ = RT_FUSION_KERNEL_DUMPFLAG; + } else { + dump_flag_ = RT_KERNEL_DUMPFLAG; + } dump_args_ = static_cast(args_) + sizeof(aicpu::AicpuParamHead); } diff --git a/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.h b/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.h index cc8edc07..8ada2082 100644 --- a/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.h @@ -127,6 +127,7 @@ class KernelTaskInfo : public TaskInfo { static void FreeRtMem(void **ptr); Status SuperKernelDistribute(); + bool IsL1FusionOp(const OpDescPtr &op_desc); // For super kernel Status SaveSKTDumpInfo(); diff --git a/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc index 8cac9f82..1f542154 100644 --- a/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc @@ -59,7 +59,12 @@ Status MemcpyAddrAsyncTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel // malloc args memory size_t args_size = sizeof(void *) * io_addrs.size(); - rtError_t rt_ret = rtMalloc(&args_, args_size + kAlignBytes, RT_MEMORY_HBM); + rtMemType_t memory_type = RT_MEMORY_HBM; + if (op_desc->HasAttr(ATTR_NAME_MEMORY_TYPE_RANGE)) { + memory_type = RT_MEMORY_TS_4G; + } + GELOGI("memory_type: %u", memory_type); + rtError_t rt_ret = rtMalloc(&args_, args_size + kAlignBytes, memory_type); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); return RT_ERROR_TO_GE_STATUS(rt_ret); diff --git a/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc 
b/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc index 1cc18a85..96247e7d 100644 --- a/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc @@ -36,6 +36,12 @@ Status MemcpyAsyncTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *da count_ = memcpy_async.count(); kind_ = memcpy_async.kind(); dst_max_ = memcpy_async.dst_max(); + OpDescPtr op_desc = davinci_model->GetOpByIndex(memcpy_async.op_index()); + if (op_desc == nullptr) { + GELOGE(INTERNAL_ERROR, "Task op index:%u out of range", memcpy_async.op_index()); + return INTERNAL_ERROR; + } + if (davinci_model->IsKnownNode()) { src_ = reinterpret_cast(davinci_model_->GetCurrentArgsAddr(args_offset_)); dst_ = reinterpret_cast(reinterpret_cast(src_) + sizeof(void *)); @@ -49,9 +55,17 @@ Status MemcpyAsyncTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *da return ret; } - ret = ModelUtils::GetRtAddress(davinci_model->GetRuntimeParam(), memcpy_async.dst(), dst_); - if (ret != SUCCESS) { - return ret; + // dst_ needs different address for different chips + if (op_desc->HasAttr(ATTR_NAME_MEMORY_TYPE_RANGE)) { + ret = AllocTsMemoryForMemcpy(op_desc, davinci_model); + if (ret != SUCCESS) { + return ret; + } + } else { + ret = ModelUtils::GetRtAddress(davinci_model->GetRuntimeParam(), memcpy_async.dst(), dst_); + if (ret != SUCCESS) { + return ret; + } } GELOGI("MemcpyAsyncTaskInfo Init Success, logic[0x%lx, 0x%lx], src:%p, dst:%p, max:%lu, count:%lu", @@ -108,5 +122,33 @@ Status MemcpyAsyncTaskInfo::UpdateArgs() { return SUCCESS; } +Status MemcpyAsyncTaskInfo::AllocTsMemoryForMemcpy(const OpDescPtr &op_desc, DavinciModel *davinci_model) { + int64_t size = 0; + auto tensor_desc = op_desc->GetOutputDescPtr(0); + if ((tensor_desc == nullptr) || (TensorUtils::GetTensorSizeInBytes(*tensor_desc, size) != GRAPH_SUCCESS)) { + GELOGE(FAILED, "GetTensorSizeInBytes failed!"); + return FAILED; + } + 
+ rtError_t rt_ret = rtMalloc(&memory_4g_, size, RT_MEMORY_TS_4G); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "rtMalloc failed, ret: 0x%X", rt_ret); + return FAILED; + } + + // map save the opdesc's offset and special address, for update the streamSwitchN's input address + std::map memcpy_4g_offset_addr; + vector offsets = op_desc->GetOutputOffset(); + if (offsets.empty()) { + GELOGE(FAILED, "GetOutputOffset failed!"); + return FAILED; + } + memcpy_4g_offset_addr.insert(std::pair(offsets[0], memory_4g_)); + davinci_model->SetMemcpyOffsetAndAddr(memcpy_4g_offset_addr); + + dst_ = reinterpret_cast(memory_4g_); + return SUCCESS; +} + REGISTER_TASK_INFO(RT_MODEL_TASK_MEMCPY_ASYNC, MemcpyAsyncTaskInfo); } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.h b/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.h index c3daa862..9436529d 100644 --- a/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.h @@ -18,15 +18,24 @@ #define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MEMCPY_ASYNC_TASK_INFO_H_ #include "graph/load/new_model_manager/task_info/task_info.h" +#include "graph/op_desc.h" namespace ge { class MemcpyAsyncTaskInfo : public TaskInfo { public: - MemcpyAsyncTaskInfo() : dst_(nullptr), dst_max_(0), src_(nullptr), count_(0), kind_(0) {} + MemcpyAsyncTaskInfo() : dst_(nullptr), dst_max_(0), src_(nullptr), count_(0), kind_(0), memory_4g_(nullptr) {} ~MemcpyAsyncTaskInfo() override { src_ = nullptr; dst_ = nullptr; + + if (memory_4g_ != nullptr) { + rtError_t ret = rtFree(memory_4g_); + if (ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", ret); + } + memory_4g_ = nullptr; + } } Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; @@ -38,6 +47,7 @@ class MemcpyAsyncTaskInfo : public TaskInfo { Status CalculateArgs(const domi::TaskDef &task_def, 
DavinciModel *davinci_model) override; private: + Status AllocTsMemoryForMemcpy(const OpDescPtr &op_desc, DavinciModel *davinci_model); uint8_t *dst_; uint64_t dst_max_; uint8_t *src_; @@ -46,6 +56,7 @@ class MemcpyAsyncTaskInfo : public TaskInfo { DavinciModel *davinci_model_ = nullptr; uint32_t args_offset_ = 0; domi::MemcpyAsyncDef memcpy_async; + void *memory_4g_; }; } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MEMCPY_ASYNC_TASK_INFO_H_ diff --git a/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.cc index d134dfdd..d95aefac 100644 --- a/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.cc @@ -66,16 +66,13 @@ Status StreamSwitchNTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel * GELOGE(FAILED, "Get true stream ptr of switchN op failed."); return FAILED; } - if (davinci_model->IsKnownNode()) { - input_ptr_ = davinci_model->GetCurrentFixedAddr(args_offset_); - } else { - auto input_data_addr = ModelUtils::GetInputDataAddrs(davinci_model->GetRuntimeParam(), op_desc); - if (input_data_addr.empty()) { - GELOGE(FAILED, "Input data addr is nullptr."); - return FAILED; - } - input_ptr_ = input_data_addr[0]; + + // update StreamSwitchN's input_ptr_ + Status ret = InputPtrUpdate(op_desc, davinci_model); + if (ret != SUCCESS) { + return ret; } + davinci_model->DisableZeroCopy(input_ptr_); GELOGI("StreamSwitchNTaskInfo Init Success, inputSize:%u, elementSize:%d, trueStreamID:%ld.", input_size_, element_size_, op_desc->GetStreamId()); @@ -154,5 +151,36 @@ Status StreamSwitchNTaskInfo::CalculateArgs(const domi::TaskDef &task_def, Davin GELOGI("Calculate stream switchn task args , tensor_size %ld, args_offset %ld", tensor_size, args_offset_); return SUCCESS; } + +Status StreamSwitchNTaskInfo::InputPtrUpdate(const OpDescPtr &op_desc, DavinciModel 
*davinci_model) { + bool is_4g_mem = false; + const map memcpy_4g_offset_addr = davinci_model->GetMemcpyOffsetAndAddr(); + vector input_offset = op_desc->GetInputOffset(); + if (input_offset.empty()) { + GELOGE(FAILED, "Get StreamSwitchN's input offset failed."); + return FAILED; + } + + auto iter = memcpy_4g_offset_addr.find(input_offset[0]); + if (iter != memcpy_4g_offset_addr.end()) { + input_ptr_ = iter->second; + is_4g_mem = true; + } + + if (is_4g_mem == false) { + if (davinci_model->IsKnownNode()) { + input_ptr_ = davinci_model->GetCurrentFixedAddr(args_offset_); + } else { + auto input_data_addr = ModelUtils::GetInputDataAddrs(davinci_model->GetRuntimeParam(), op_desc); + if (input_data_addr.empty()) { + return FAILED; + } + input_ptr_ = input_data_addr[0]; + } + } + + GELOGI("StreamSwitchN's input_ptr is %p, is_4g_mem: %d", input_ptr_, is_4g_mem); + return SUCCESS; +} REGISTER_TASK_INFO(RT_MODEL_TASK_STREAM_SWITCH_N, StreamSwitchNTaskInfo); -} // namespace ge \ No newline at end of file +} // namespace ge diff --git a/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.h b/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.h index 1a96243a..5a73eb1a 100644 --- a/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.h @@ -42,6 +42,7 @@ class StreamSwitchNTaskInfo : public TaskInfo { private: Status GetTrueStreamPtr(const OpDescPtr &op_desc, DavinciModel *davinci_model); + Status InputPtrUpdate(const OpDescPtr &op_desc, DavinciModel *davinci_model); void *input_ptr_; uint32_t input_size_; void *value_ptr_; diff --git a/src/ge/graph/load/new_model_manager/zero_copy_task.cc b/src/ge/graph/load/new_model_manager/zero_copy_task.cc index 5b595d76..00920aad 100644 --- a/src/ge/graph/load/new_model_manager/zero_copy_task.cc +++ b/src/ge/graph/load/new_model_manager/zero_copy_task.cc @@ -143,10 +143,11 @@ Status 
ZeroCopyTask::UpdateTaskParam(uintptr_t addr, void *buffer_addr, const ma /** * @ingroup ge * @brief Update task param to device. + * @param [in] async_mode: true for asychronous mode. * @param [in] stream: Stream for asychronous update. * @return: 0 SUCCESS / others FAILED */ -Status ZeroCopyTask::DistributeParam(rtStream_t stream) { +Status ZeroCopyTask::DistributeParam(bool async_mode, rtStream_t stream) { if (!is_updated_) { return SUCCESS; } @@ -154,7 +155,7 @@ Status ZeroCopyTask::DistributeParam(rtStream_t stream) { is_updated_ = false; GE_CHECK_NOTNULL(args_addr_); rtError_t rt_err = RT_ERROR_NONE; - if (stream != nullptr) { + if (async_mode) { rt_err = rtMemcpyAsync(args_addr_, args_size_, args_info_.data(), args_info_.size(), RT_MEMCPY_HOST_TO_DEVICE_EX, stream); } else { diff --git a/src/ge/graph/load/new_model_manager/zero_copy_task.h b/src/ge/graph/load/new_model_manager/zero_copy_task.h index d2a91ce7..799844a5 100644 --- a/src/ge/graph/load/new_model_manager/zero_copy_task.h +++ b/src/ge/graph/load/new_model_manager/zero_copy_task.h @@ -77,10 +77,11 @@ class ZeroCopyTask { /** * @ingroup ge * @brief Update task param to device. + * @param [in] async_mode: true for asychronous mode. * @param [in] stream: Stream for asychronous update. 
* @return: 0 SUCCESS / others FAILED */ - ge::Status DistributeParam(rtStream_t stream); + ge::Status DistributeParam(bool async_mode, rtStream_t stream); protected: bool CheckDynamicBatch(const map> &batch_addrs, const string &batch_label, uintptr_t addr); @@ -97,4 +98,4 @@ class ZeroCopyTask { map> task_addr_offset_; }; } // namespace ge -#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_ZERO_COPY_TASK_H_ \ No newline at end of file +#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_ZERO_COPY_TASK_H_ diff --git a/src/ge/graph/manager/graph_manager.cc b/src/ge/graph/manager/graph_manager.cc index bdf2143c..582b206a 100644 --- a/src/ge/graph/manager/graph_manager.cc +++ b/src/ge/graph/manager/graph_manager.cc @@ -91,7 +91,13 @@ #include "graph/passes/variable_ref_delete_op_pass.h" #include "graph/passes/variable_ref_useless_control_out_delete_pass.h" #include "graph/passes/end_of_sequence_add_control_pass.h" +#include "graph/passes/subexpression_migration_pass.h" +#include "graph/passes/unused_args_clean_pass.h" +#include "graph/passes/global_step_insert_pass.h" #include "graph/utils/tensor_adapter.h" +#include "graph/utils/type_utils.h" +#include "graph/graph_util.h" +#include "graph/types.h" #include "inc/pass_manager.h" #include "init/gelib.h" @@ -102,6 +108,8 @@ const char *const kNetOutput = "NetOutput"; const char *const kVariable = "Variable"; const char *const kSend = "Send"; const char *const kRecv = "Recv"; +const char *const kCheckPointForGetVar = "CheckPointGraphForGetVar"; +const char *const kCheckPointGraph = "checkpoint_graph"; bool IsTailingOptimization() { string is_tailing_optimization_option; @@ -380,6 +388,11 @@ Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vector &inputs, GeRootModelPtr &ge_root_model, uint64_t session_id) { @@ -427,6 +455,8 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std: ret = IncreBuild(graph_node, ge_model); if (ret != SUCCESS) { ret = PreRun(graph_node, inputs, ge_root_model, 
session_id); + // release rts generate context + RtContextUtil::GetInstance().DestroyRtContexts(session_id); if (ret != SUCCESS) { GELOGE(ret, "PreRun Failed."); return ret; @@ -1388,6 +1418,9 @@ bool GraphManager::CheckNetOutputForCheckpointGraph(NodePtr &node) { } bool GraphManager::CheckVariableForCheckpointGraph(NodePtr &node) { + if (node->GetOpDesc()->HasAttr(kCheckPointForGetVar)) { + return false; + } auto out = node->GetOutDataAnchor(0); if (out == nullptr) { GELOGE(GE_GRAPH_PARAM_NULLPTR, "out is nullptr."); @@ -1573,48 +1606,6 @@ Status GraphManager::RemoveIsolatedConst(ge::ComputeGraphPtr &compute_graph) { return SUCCESS; } -Status GraphManager::NewOptimizeAfterMergeSubGraph(ge::ComputeGraphPtr &compute_graph) { - GELOGD("NewOptimizeAfterMergeSubGraph in"); - - GEPass ge_passes(compute_graph); - NamesToPass names_to_passes; - ConstantFoldingPass constant_folding_pass; - names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass); - GE_TIMESTAMP_START(names_to_passes); - auto ret = ge_passes.Run(names_to_passes); - GE_TIMESTAMP_END(names_to_passes, "GraphManager::ge_passes"); - if (ret != SUCCESS) { - GELOGE(ret, "Run ge_passes optimize for OptimizeAfterMergeSubGraph failed, ret:%d.", ret); - return ret; - } - - ret = RemoveIsolatedConst(compute_graph); - if (ret != SUCCESS) { - GELOGE(ret, "Remove isolated Constant failed, ret:%d.", ret); - return ret; - } - - PassManager passes; - GE_CHK_STATUS_RET(passes.AddPass("MultiBatchPass", new (std::nothrow) MultiBatchPass)); - GE_CHK_STATUS_RET(passes.AddPass("CompileNodesPass", new (std::nothrow) CompileNodesPass)); - GE_CHK_STATUS_RET(passes.AddPass("AtomicAddrCleanPass", new (std::nothrow) AtomicAddrCleanPass)); - - GE_TIMESTAMP_START(passes); - ret = passes.Run(compute_graph); - GE_TIMESTAMP_END(passes, "GraphManager::passes"); - if (ret != SUCCESS && ret != NOT_CHANGED) { - GELOGE(ret, "Run passes optimize for OptimizeAfterMergeSubGraph failed"); - return ret; - } - - ret = 
compute_graph->TopologicalSorting(); - if (ret != SUCCESS) { - GELOGE(ret, "Graph topological sort failed, ret:%d.", ret); - return ret; - } - return SUCCESS; -} - Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { string options = "default"; if (GetContext().GetOption("ge.exec.variable_acc", options) != SUCCESS) { @@ -1721,10 +1712,17 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { graph_pass.AddPass("OptimizeStage1_3::SwitchToStreamSwitchPass", new (std::nothrow) SwitchToStreamSwitchPass)) GE_CHK_STATUS_RET( graph_pass.AddPass("OptimizeStage1_3::AttachStreamLabelPass", new (std::nothrow) AttachStreamLabelPass)) + GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::MultiBatchPass", new (std::nothrow) MultiBatchPass(true))) GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::IteratorOpPass", new (std::nothrow) IteratorOpPass)) GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::VariableRefUselessControlOutDeletePass", new (std::nothrow) VariableRefUselessControlOutDeletePass)) GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ReshapeRecoveryPass", new (std::nothrow) ReshapeRecoveryPass)) + if (options_.train_graph_flag) { + // Priority: The GlobalStepInsertPass should work before graph partitioner. 
+ // Reason: Make sure that the var "global_step" can be partitioned to known sub graph and allocated memory + GE_CHK_STATUS_RET( + graph_pass.AddPass("OptimizeStage1_3::GlobalStepInsertPass", new (std::nothrow) GlobalStepInsertPass)) + } GE_TIMESTAMP_START(graph_pass); ret = graph_pass.Run(compute_graph); GE_TIMESTAMP_END(graph_pass, "GraphManager::OptimizeStage1_3"); @@ -1787,11 +1785,8 @@ Status GraphManager::OptimizeStage2(ge::ComputeGraphPtr &compute_graph) { PassManager pass_for_control_attr_optimize; if (options_.train_graph_flag) { - const char *unknown_shape_skip = std::getenv("EXPERIMENTAL_DYNAMIC_PARTITION"); - if (unknown_shape_skip == nullptr) { - GE_CHK_STATUS_RET(pass_for_control_attr_optimize.AddPass("OptimizeStage2::ControlAttrOptimize::FlowCtrlPass", - new (std::nothrow) FlowCtrlPass)) - } + GE_CHK_STATUS_RET(pass_for_control_attr_optimize.AddPass("OptimizeStage2::ControlAttrOptimize::FlowCtrlPass", + new (std::nothrow) FlowCtrlPass)) } GE_CHK_STATUS_RET(pass_for_control_attr_optimize.AddPass("OptimizeStage2::ControlAttrOptimize::MultiBatchPass", @@ -1821,14 +1816,10 @@ Status GraphManager::OptimizeStage2(ge::ComputeGraphPtr &compute_graph) { pass_for_control_attr_optimize.AddPass("OptimizeStage2::AfterMergePasses::" "EndOfSequenceAddControlPass", new (std::nothrow) EndOfSequenceAddControlPass)) - - const char *unknown_shape_skip = std::getenv("EXPERIMENTAL_DYNAMIC_PARTITION"); - if (unknown_shape_skip == nullptr) { - // SubgraphPass solves memory_assign_conflicts by insert MemcpyAsync node, which depends on multi attrs and - // graph-structure. So try not to add new pass after SubgraphPass. - GE_CHK_STATUS_RET(pass_for_control_attr_optimize.AddPass("OptimizeStage2::ControlAttrOptimize::SubgraphPass", - new (std::nothrow) SubgraphPass)) - } + // SubgraphPass solves memory_assign_conflicts by insert MemcpyAsync node, which depends on multi attrs and + // graph-structure. So try not to add new pass after SubgraphPass. 
+ GE_CHK_STATUS_RET(pass_for_control_attr_optimize.AddPass("OptimizeStage2::ControlAttrOptimize::SubgraphPass", + new (std::nothrow) SubgraphPass)) // AttachStreamLabelPass modifies attr without changing structure of compute_graph GE_CHK_STATUS_RET(pass_for_control_attr_optimize.AddPass("OptimizeStage2::ControlAttrOptimize::AttachStreamLabelPass", new (std::nothrow) AttachStreamLabelPass)) @@ -1870,120 +1861,6 @@ void GraphManager::ChangeConstTypeWhenTraining(const ComputeGraphPtr &compute_gr } } } -Status GraphManager::OptimizeAfterMergeSubGraph(ge::ComputeGraphPtr &compute_graph) { - GELOGI("Start optimize after merge sub graph."); - - GEPass ge_passes_for_shape(compute_graph); - NamesToPass names_to_passes_for_shape; - CastRemovePass cast_remove_pass; - names_to_passes_for_shape.emplace_back("CastRemovePass", &cast_remove_pass); - TransposeTransDataPass transpose_transdata_pass; - names_to_passes_for_shape.emplace_back("TransposeTransDataPass", &transpose_transdata_pass); - GE_TIMESTAMP_START(ge_passes_for_shape); - Status ret = ge_passes_for_shape.Run(names_to_passes_for_shape); - GE_TIMESTAMP_END(ge_passes_for_shape, "GraphManager::GePassesForShape"); - GE_CHK_STATUS_RET(ret, "Run ge_passes_for_shape optimize for OptimizeAfterMergeSubGraph failed, ret:%d.", ret); - - string options = "default"; - if (GetContext().GetOption("ge.exec.variable_acc", options) != SUCCESS) { - GELOGI("get ge.exec.variable_acc failed. 
set default value."); - } - PassManager after_merge_passes; - GE_CHK_STATUS_RET(after_merge_passes.AddPass("PermutePass", new (std::nothrow) PermutePass)); - GE_IF_BOOL_EXEC(options == "default" || options == "1", GELOGI("turn on variable accelerator"); GE_CHK_STATUS_RET( - after_merge_passes.AddPass("VariableOpPass", new (std::nothrow) VariableOpPass(&var_acc_ctrl_)))); - ret = after_merge_passes.Run(compute_graph); - if (ret != SUCCESS && ret != NOT_CHANGED) { - GELOGE(ret, "Run passes after merge sub graph failed, ret:%d.", ret); - return ret; - } - - // reshape remove + symmetry_elimination_pass to replace transop depth fusion pass - GEPass ge_passes_symmetry(compute_graph); - NamesToPass names_to_passes_for_symmetry; - ReshapeRemovePass reshape_remove_pass; - names_to_passes_for_symmetry.emplace_back("ReshapeRemovePass", &reshape_remove_pass); - TransOpSymmetryEliminationPass symmetry_elimination_pass; - names_to_passes_for_symmetry.emplace_back("TransOpSymmetryEliminationPass", &symmetry_elimination_pass); - ret = ge_passes_symmetry.Run(names_to_passes_for_symmetry); - GE_CHK_STATUS_RET(ret, "Run ge_passes optimize for OptimizeAfterMergeSubGraph failed, ret:%d.", ret); - - PassManager after_merge_fusion_passes; - GE_CHK_STATUS_RET(after_merge_fusion_passes.AddPass("TransOpWithoutReshapeFusionPass", - new (std::nothrow) TransOpWithoutReshapeFusionPass)); - GE_CHK_STATUS_RET( - after_merge_fusion_passes.AddPass("TransOpBreadthFusionPass", new (std::nothrow) TransOpBreadthFusionPass)); - GE_CHK_STATUS_RET( - after_merge_fusion_passes.AddPass("VariableRefDeleteOpPass", new (std::nothrow) VariableRefDeleteOpPass)); - GE_CHK_STATUS_RET(after_merge_fusion_passes.AddPass("SameTransdataBreadthFusionPass", - new (std::nothrow) SameTransdataBreadthFusionPass)); - GE_CHK_STATUS_RET( - after_merge_fusion_passes.AddPass("MarkGraphUnknownStatusPass", new (std::nothrow) MarkGraphUnknownStatusPass)); - GE_CHK_STATUS_RET(after_merge_fusion_passes.AddPass("AtomicAddrCleanPass", 
new (std::nothrow) AtomicAddrCleanPass)); - GE_CHK_STATUS_RET(after_merge_fusion_passes.AddPass( - "LinkGenMaskNodesPass", new (std::nothrow) LinkGenMaskNodesPass(options_.stream_max_parallel_num))); - GE_TIMESTAMP_START(after_merge_fusion_passes); - ret = after_merge_fusion_passes.Run(compute_graph); - GE_TIMESTAMP_END(after_merge_fusion_passes, "GraphManager::AfterMergePasses"); - if (ret != SUCCESS && ret != NOT_CHANGED) { - GELOGE(ret, "Run passes after merge sub graph failed, ret:%d.", ret); - return ret; - } - - // add variable attr for hccl broadcast,need to be removed after variable pass online - for (const ge::NodePtr &node : compute_graph->GetDirectNode()) { - if (node->GetOpDesc()->GetType() != VARIABLE) { - continue; - } - - if (IsBroadCastOpData(node)) { - AdjustBroadCastOpData(node); - } - if (IsAssignOpData(node)) { - AdjustAssignOpData(node); - } - } - - GEPass ge_passes(compute_graph); - NamesToPass names_to_passes; - TransOpNearbyAllreduceFusionPass trans_op_nearby_allreduce_fusion_pass; - names_to_passes.emplace_back("TransOpNearbyAllreduceFusionPass", &trans_op_nearby_allreduce_fusion_pass); - names_to_passes_for_shape.emplace_back("ReshapeRemovePass", &reshape_remove_pass); - ConstantFoldingPass constant_folding_pass; - names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass); - DimensionAdjustPass dimension_adjust_pass; - names_to_passes.emplace_back("DimensionAdjustPass", &dimension_adjust_pass); - CondRemovePass condition_remove_pass; - names_to_passes.emplace_back("CondRemovePass", &condition_remove_pass); - GE_TIMESTAMP_START(names_to_passes); - ret = ge_passes.Run(names_to_passes); - GE_TIMESTAMP_END(names_to_passes, "GraphManager::MergedGraphNameToPasses"); - GE_CHK_STATUS_RET(ret, "Run ge_passes optimize for OptimizeAfterMergeSubGraph failed, ret:%d.", ret); - - ret = RemoveIsolatedConst(compute_graph); - GE_CHK_STATUS_RET(ret, "Remove isolated Constant failed, ret:%d.", ret); - - PassManager pass_for_optimize; - 
const char *unknown_shape_skip = std::getenv("EXPERIMENTAL_DYNAMIC_PARTITION"); - if (unknown_shape_skip == nullptr) { - GE_CHK_STATUS_RET(pass_for_optimize.AddPass("SubgraphPass", new (std::nothrow) SubgraphPass)); - } - GE_CHK_STATUS_RET(pass_for_optimize.AddPass("MultiBatchPass", new (std::nothrow) MultiBatchPass)); - GE_CHK_STATUS_RET(pass_for_optimize.AddPass("CompileNodesPass", new (std::nothrow) CompileNodesPass)); - GE_TIMESTAMP_START(pass_for_optimize); - ret = pass_for_optimize.Run(compute_graph); - GE_TIMESTAMP_END(pass_for_optimize, "GraphManager::OptimizePass"); - if (ret != SUCCESS && ret != NOT_CHANGED) { - GELOGE(ret, "Run optimize pass failed"); - return ret; - } - - ret = compute_graph->TopologicalSorting(); - GE_CHK_STATUS_RET(ret, "Graph topological sort failed, ret:%d.", ret); - - GELOGI("End optimize after merge sub graph."); - return SUCCESS; -} Status GraphManager::LoadGraphAsync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) { GELOGI("[LoadGraphAsync] run_graph_flag[%d], graph_id[%u]", options_.run_graph_flag, graph_node->GetGraphId()); @@ -2185,6 +2062,19 @@ Status GraphManager::IncreBuild(const GraphNodePtr &graph_node, GeModelPtr &ge_m return FAILED; } +void GraphManager::ConstructGeInput(std::vector &ge_inputs, PreRunArgs &args) { + for (auto const &input : args.input_tensor) { + std::vector input_dims; + std::transform(input.dims.begin(), input.dims.end(), std::back_inserter(input_dims), + [](int64_t x) -> int64_t { return x; }); + GeShape input_shape(input_dims); + GeTensorDesc input_tensor_desc; + input_tensor_desc.SetShape(input_shape); + input_tensor_desc.SetDataType(static_cast(input.data_type)); + ge_inputs.emplace_back(input_tensor_desc); + } +} + void GraphManager::PreRunThread(GraphManager *graph_manager) { if (prctl(PR_SET_NAME, ("GE_PreRun")) != 0) { GELOGW("Set thread name failed."); @@ -2198,16 +2088,8 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { GetThreadLocalContext() = 
args.context; GELOGI("A new loop start."); std::vector ge_inputs; - for (auto const &input : args.input_tensor) { - std::vector input_dims; - std::transform(input.dims.begin(), input.dims.end(), std::back_inserter(input_dims), - [](int64_t x) -> int64_t { return x; }); - GeShape input_shape(input_dims); - GeTensorDesc input_tensor_desc; - input_tensor_desc.SetShape(input_shape); - input_tensor_desc.SetDataType(static_cast(input.data_type)); - ge_inputs.emplace_back(input_tensor_desc); - } + ConstructGeInput(ge_inputs, args); + // find graph GraphNodePtr graph_node = nullptr; Status ret = graph_manager->GetGraphNode(args.graph_id, graph_node); @@ -2229,14 +2111,11 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { graph_node->SetRunFlag(true); ComputeGraphPtr compute_graph_tmp = GraphUtils::GetComputeGraph(*(graph_node->GetGraph())); - - if (graph_manager->GetTrainFlag()) { - if (compute_graph_tmp == nullptr) { - ReturnError(graph_manager, args.callback, GE_GRAPH_GRAPH_NODE_NULL, - "[RunGraph] compute_graph_tmp is NULL, graph id = %u."); - graph_node->Unlock(); - return; - } + if (compute_graph_tmp == nullptr) { + ReturnError(graph_manager, args.callback, GE_GRAPH_GRAPH_NODE_NULL, + "[RunGraph] compute_graph_tmp is NULL, graph id = %u."); + graph_node->Unlock(); + return; } // when set incre build, save cache helper. 
graph_manager->AddModelCacheHelperToMap(args.graph_id, args.session_id, compute_graph_tmp); @@ -2266,11 +2145,19 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { GeModelPtr ge_model = nullptr; if (graph_manager->IncreBuild(graph_node, ge_model) != SUCCESS) { ret = graph_manager->PreRun(graph_node, ge_inputs, ge_root_model, args.session_id); + // release rts generate context + RtContextUtil::GetInstance().DestroyRtContexts(args.session_id); if (ret != SUCCESS) { graph_node->SetRunFlag(false); - ReturnError(graph_manager, args.callback, ret, "PreRun Failed, thread exit.."); - graph_node->Unlock(); - return; + if (!std::getenv("AnalyzeMode")) { + ReturnError(graph_manager, args.callback, ret, "PreRun Failed, thread exit.."); + graph_node->Unlock(); + return; + } else { + ReturnError(graph_manager, graph_node, args.callback, ret, "PreRun Failed, keep geop continue!"); + graph_node->Unlock(); + continue; + } } } graph_node->SetBuildFlag(true); @@ -2350,13 +2237,74 @@ void GraphManager::ReturnError(GraphManager *graph_manager, RunAsyncCallback cal if (graph_manager == nullptr) { return; } - - GELOGE(ret, "%s.", log.c_str()); StopQueue(graph_manager); + GELOGE(ret, "%s.", log.c_str()); std::vector outputs; callback(ret, outputs); } +void GraphManager::ReturnError(GraphManager *graph_manager, GraphNodePtr &graph_node, RunAsyncCallback callback, + Status ret, const string &log) { + std::vector outputs; + auto compute_graph = GraphUtils::GetComputeGraph(*graph_node->GetGraph()); + if (graph_manager == nullptr || compute_graph == nullptr) { + GELOGE(GRAPH_FAILED, "[Analyze Mode] compute graph is null!"); + callback(GRAPH_FAILED, outputs); + return; + } + + for (const auto &node : compute_graph->GetAllNodes()) { + if (node->GetType() != "NetOutput") { + continue; + } + for (size_t i = 0; i < node->GetAllInDataAnchorsSize(); i++) { + auto input_desc = node->GetOpDesc()->MutableInputDesc(i); + ge::OutputTensorInfo tensor; + tensor.dims = 
input_desc->GetShape().GetDims(); + tensor.data_type = static_cast(input_desc->GetDataType()); + int64_t len = 1; + if (input_desc->GetShape().GetDims() != std::vector({})) { + len = input_desc->GetShape().GetShapeSize(); + } + if (len < 0) { + GELOGE(GRAPH_FAILED, "Analyze Mode does not support GEOP output unknown shape!"); + callback(GRAPH_FAILED, outputs); + return; + } else if (len == 0) { + GELOGI("getted shape size is 0.Do process as empty tensor!"); + len = 1; + } + auto size = GetSizeByDataType(input_desc->GetDataType()); + if (size <= 0) { + GELOGE(PARAM_INVALID, "Failed to get cube size, the data type %s is invalid", + ge::TypeUtils::DataTypeToSerialString(input_desc->GetDataType()).c_str()); + callback(GRAPH_FAILED, outputs); + return; + } + if (CheckInt64MulOverflow(len, static_cast(size)) != true) { + GELOGE(MEMALLOC_FAILED, "int64 multiply happens overflow! a:%ld b:%d", len, size); + callback(GRAPH_FAILED, outputs); + return; + } + tensor.length = len * size; + auto pbuff = new (std::nothrow) uint8_t[tensor.length]; + if (!pbuff) { + GELOGE(MEMALLOC_FAILED, "new buff failed!"); + callback(GRAPH_FAILED, outputs); + return; + } + // To avoid global step too small and can not stop, totally set a bigger value + for (int64_t i = 0; i < tensor.length; i++) { + *(pbuff + i) = 0x7F; // here stands for a positive max value + } + tensor.data.reset(pbuff); + outputs.emplace_back(std::move(tensor)); + } + } + callback(SUCCESS, outputs); + return; +} + bool GraphManager::IsGraphNeedRebuild(uint32_t graph_id) { // find graph GraphNodePtr graph_node = nullptr; @@ -2479,4 +2427,99 @@ Status GraphManager::Build(const GraphNodePtr &graph_node, ComputeGraphPtr &comp graph_node->SetGeRootModel(ge_root_model); return SUCCESS; } + +Status GraphManager::GenCheckPointGraph(const std::map &all_variables, Graph &graph) { + ge::ComputeGraphPtr compute_graph = MakeShared(kCheckPointGraph); + GE_CHECK_NOTNULL(compute_graph); + OpDescPtr save_desc = 
MakeShared(compute_graph->GetName() + "_" + kSave, kSave); + GE_CHECK_NOTNULL(save_desc); + uint32_t save_index = 0; + for (auto iter = all_variables.begin(); iter != all_variables.end(); ++iter) { + GE_CHK_GRAPH_STATUS_RET(save_desc->AddInputDesc(save_index, iter->second)); + save_index++; + } + NodePtr save_node = compute_graph->AddNode(save_desc); + + uint32_t index = 0; + for (auto iter = all_variables.begin(); iter != all_variables.end(); ++iter) { + OpDescPtr var_desc = MakeShared(iter->first, VARIABLE); + GE_CHECK_NOTNULL(var_desc); + if (!AttrUtils::SetBool(var_desc, kCheckPointForGetVar, true)) { + GELOGW("Set check point graph attr failed."); + } + GE_CHK_GRAPH_STATUS_RET(var_desc->AddOutputDesc(iter->second)); + NodePtr var_node = compute_graph->AddNode(var_desc); + GE_CHK_STATUS(GraphUtils::AddEdge(var_node->GetOutDataAnchor(0), save_node->GetInDataAnchor(index)), + "Add edge[%s->%s] fail.", var_node->GetName().c_str(), save_node->GetName().c_str()); + index++; + } + compute_graph->Dump(); + graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph); + return SUCCESS; +} + +Status GraphManager::SaveVariables(const Graph &graph, const std::vector &var_names, + const std::vector &outputs, std::vector &var_values) { + map var_results; + GE_CHK_STATUS_RET(SaveCheckPointResult(graph, outputs, var_results), "Save check point result failed."); + if (!var_names.empty()) { + for (const auto &var_name : var_names) { + if (var_results.count(var_name) == 0) { + GELOGE(FAILED, "Fetch var[%s] value failed.", var_name.c_str()); + return FAILED; + } else { + var_values.emplace_back(var_results[var_name]); + } + } + } else { + for (auto iter = var_results.begin(); iter != var_results.end(); ++iter) { + var_values.emplace_back(iter->second); + } + } + return SUCCESS; +} + +Status GraphManager::SaveCheckPointResult(const Graph &graph, const std::vector &outputs, + map &var_results) { + auto compute_graph = GraphUtils::GetComputeGraph(graph); + NodePtr netoutput_node = 
nullptr; + for (const auto &node : compute_graph->GetAllNodes()) { + if (node->GetType() == NETOUTPUT) { + netoutput_node = node; + break; + } + } + GE_CHECK_NOTNULL(netoutput_node); + for (const auto &in : netoutput_node->GetAllInDataAnchors()) { + auto out_anchor = in->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(out_anchor); + auto peer_node = out_anchor->GetOwnerNode(); + while (peer_node->GetType() != VARIABLE) { + if (peer_node->GetAllInDataAnchors().size() != 1) { + GELOGE(FAILED, "peer_node [%s] has more than 1 input in checkpoint Graph.", peer_node->GetName().c_str()); + return FAILED; + } + auto peer_node_in_anchor = peer_node->GetAllInDataAnchors().at(0); + auto peer_node_out_anchor = peer_node_in_anchor->GetPeerOutAnchor(); + if (peer_node_out_anchor != nullptr) { + peer_node = peer_node_out_anchor->GetOwnerNode(); + if (peer_node->GetType() == VARIABLE) { + break; + } + } + } + if (peer_node->GetType() != VARIABLE) { + GELOGE(FAILED, " peer_node %s is not variable in checkpoint Graph.", peer_node->GetName().c_str()); + return FAILED; + } + auto var_name = peer_node->GetName(); + GELOGI("[GraphManager] SaveVariables, varName is %s.", var_name.c_str()); + if (in->GetIdx() >= static_cast(outputs.size())) { + GELOGE(FAILED, "variable index[%d] out of range[%zu].", in->GetIdx(), outputs.size()); + return FAILED; + } + var_results.emplace(var_name, outputs.at(in->GetIdx())); + } + return SUCCESS; +} } // namespace ge diff --git a/src/ge/graph/manager/graph_manager.h b/src/ge/graph/manager/graph_manager.h index 681efac8..6dc83120 100644 --- a/src/ge/graph/manager/graph_manager.h +++ b/src/ge/graph/manager/graph_manager.h @@ -159,6 +159,13 @@ class GraphManager { void SetOptionsRunGraphFlag(bool run_graph_flag); + Status GenCheckPointGraph(const std::map &all_variables, Graph &graph); + + Status SaveVariables(const Graph &graph, const std::vector &var_names, + const std::vector &outputs, std::vector &var_values); + + Status SaveCheckPointResult(const Graph &graph, 
const std::vector &outputs, map &var_results); + private: struct PreRunArgs { GraphId graph_id; @@ -267,9 +274,8 @@ class GraphManager { Status OptimizeStage1(ComputeGraphPtr &compute_graph); Status OptimizeStage2(ComputeGraphPtr &compute_graph); - Status OptimizeAfterMergeSubGraph(ge::ComputeGraphPtr &compute_graph); - Status NewOptimizeAfterMergeSubGraph(ge::ComputeGraphPtr &compute_graph); + Status SubexpressionMigration(ComputeGraphPtr &compute_graph); Status LoadGraphAsync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node); @@ -288,10 +294,13 @@ class GraphManager { Status IncreBuild(const GraphNodePtr &graph_node, GeModelPtr &ge_model); void RemoveModelCacheHelper(const GraphId &graph_id); + static void ConstructGeInput(std::vector &ge_inputs, PreRunArgs &args); static void PreRunThread(GraphManager *graph_manager); static void RunThread(GraphManager *graph_manager); static void StopQueue(GraphManager *graph_manager); static void ReturnError(GraphManager *graph_manager, RunAsyncCallback callback, Status ret, const string &log); + static void ReturnError(GraphManager *graph_manager, GraphNodePtr &graph_node, RunAsyncCallback callback, Status ret, + const string &log); void ChangeConstTypeWhenTraining(const ComputeGraphPtr &compute_graph); diff --git a/src/ge/graph/manager/graph_var_manager.cc b/src/ge/graph/manager/graph_var_manager.cc index 7ca0224b..8633e361 100644 --- a/src/ge/graph/manager/graph_var_manager.cc +++ b/src/ge/graph/manager/graph_var_manager.cc @@ -855,6 +855,32 @@ void VarManager::RemoveAllocatedGraphId(const std::string &var_name) { var_resource_->RemoveAllocatedGraphId(var_name); } +Status VarManager::GetAllVariables(std::map &all_variables) { + std::lock_guard lock(mutex_); + if (var_resource_ == nullptr) { + GELOGW("VarManager has not been inited."); + return INTERNAL_ERROR; + } + auto new_variable_desc = var_resource_->GetAllVarDesc(); + if (new_variable_desc.size() == 0) { + GELOGW("VarManager don't have variables."); 
+ return INTERNAL_ERROR; + } + + for (auto iter = new_variable_desc.begin(); iter != new_variable_desc.end(); ++iter) { + auto trans_road = var_resource_->GetTransRoad(iter->first); + if (trans_road == nullptr || trans_road->empty()) { + GELOGI("The variable %s does not have any trans road", iter->first.c_str()); + all_variables[iter->first] = iter->second; + continue; + } + // get origin trans info : the first trans node info + auto origin_trans_node_info = trans_road->at(0); + all_variables[iter->first] = origin_trans_node_info.input; + } + return SUCCESS; +} + VarManagerPool::~VarManagerPool() { Destory(); } VarManagerPool &VarManagerPool::Instance() { @@ -897,4 +923,22 @@ VarManager *VarManagerPool::GetVarManager(uint64_t session_id) { var_manager_map_[session_id] = var_manager; return var_manager; } + +void VarManagerPool::RemoveVarManager(uint64_t session_id) { + VarManager *var_manager = nullptr; + { + std::lock_guard lock(var_manager_mutex_); + auto it = var_manager_map_.find(session_id); + if (it != var_manager_map_.end()) { + var_manager = it->second; + var_manager_map_.erase(it); + } + } + + if (var_manager != nullptr) { + var_manager->Destory(); + delete var_manager; + var_manager = nullptr; + } +} } // namespace ge diff --git a/src/ge/graph/manager/graph_var_manager.h b/src/ge/graph/manager/graph_var_manager.h index 2142d906..4a038f13 100644 --- a/src/ge/graph/manager/graph_var_manager.h +++ b/src/ge/graph/manager/graph_var_manager.h @@ -157,6 +157,8 @@ class VarResource { bool IsVarAddr(const int64_t &offset); + std::unordered_map GetAllVarDesc() const { return cur_var_tensor_desc_map_; } + private: std::string VarKey(const std::string &var_name, const ge::GeTensorDesc &tensor_desc); @@ -276,6 +278,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { uint8_t *GetVarMemoryAddr(uint8_t *logic_addr, rtMemType_t memory_type); + Status GetAllVariables(std::map &all_variables); + private: uint32_t version_; uint64_t session_id_; @@ 
-300,6 +304,8 @@ class VarManagerPool { VarManager *GetVarManager(uint64_t session_id); + void RemoveVarManager(uint64_t session_id); + void Destory() noexcept; ge::Status Init() const; diff --git a/src/ge/graph/manager/rdma_pool_allocator.cc b/src/ge/graph/manager/rdma_pool_allocator.cc index 1daeafb8..1ff77e92 100644 --- a/src/ge/graph/manager/rdma_pool_allocator.cc +++ b/src/ge/graph/manager/rdma_pool_allocator.cc @@ -16,7 +16,6 @@ #include "graph/manager/rdma_pool_allocator.h" #include "framework/common/debug/ge_log.h" -#include "graph/manager/graph_mem_allocator.h" namespace { const size_t kAlignedSize = 512; diff --git a/src/ge/graph/manager/rdma_pool_allocator.h b/src/ge/graph/manager/rdma_pool_allocator.h index 59d33916..e1da29a9 100644 --- a/src/ge/graph/manager/rdma_pool_allocator.h +++ b/src/ge/graph/manager/rdma_pool_allocator.h @@ -27,12 +27,11 @@ #include "framework/common/ge_inner_error_codes.h" #include "graph/manager/block_memory.h" +#include "graph/manager/graph_mem_allocator.h" #include "graph/node.h" #include "runtime/mem.h" namespace ge { -class MemoryAllocator; - class RdmaPoolAllocator { public: explicit RdmaPoolAllocator(rtMemType_t memory_type); diff --git a/src/ge/graph/manager/util/rt_context_util.cc b/src/ge/graph/manager/util/rt_context_util.cc index e6344539..63f217a9 100644 --- a/src/ge/graph/manager/util/rt_context_util.cc +++ b/src/ge/graph/manager/util/rt_context_util.cc @@ -28,6 +28,10 @@ void RtContextUtil::DestroyRtContexts(uint64_t session_id) { std::lock_guard lock(ctx_mutex_); auto &contexts = rt_contexts_[session_id]; DestroyRtContexts(session_id, contexts); + auto iter = rt_contexts_.find(session_id); + if (iter != rt_contexts_.end()) { + rt_contexts_.erase(iter); + } } void RtContextUtil::DestroyAllRtContexts() { diff --git a/src/ge/graph/optimize/graph_optimize.cc b/src/ge/graph/optimize/graph_optimize.cc index 09acae33..a8de6701 100644 --- a/src/ge/graph/optimize/graph_optimize.cc +++ 
b/src/ge/graph/optimize/graph_optimize.cc @@ -16,15 +16,10 @@ #include "graph/optimize/graph_optimize.h" -#include - -#include "framework/common/debug/ge_log.h" -#include "graph/anchor.h" +#include "graph/ge_context.h" #include "graph/passes/dimension_adjust_pass.h" -#include "graph/utils/graph_utils.h" #include "inc/pass_manager.h" #include "init/gelib.h" -#include "opskernel_manager/ops_kernel_manager.h" namespace { const char *const kVectorCore = "VectorCore"; @@ -156,6 +151,11 @@ Status GraphOptimize::OptimizeOriginalGraph(ComputeGraphPtr &compute_graph) { Status GraphOptimize::OptimizeOriginalGraphJudgeInsert(ComputeGraphPtr &compute_graph) { GELOGD("OptimizeOriginalGraphJudgeInsert in"); + if (GetContext().GetHostExecFlag()) { + // graph exec on host, no need OptimizeOriginalGraph + return SUCCESS; + } + GE_CHECK_NOTNULL(compute_graph); Status ret = SUCCESS; std::shared_ptr instance_ptr = ge::GELib::GetInstance(); @@ -185,14 +185,12 @@ Status GraphOptimize::OptimizeOriginalGraphJudgeInsert(ComputeGraphPtr &compute_ return ret; } -Status GraphOptimize::NewOptimizeOriginalGraph(ComputeGraphPtr &compute_graph) { - GELOGD("NewOptimizeOriginalGraph in"); +Status GraphOptimize::OptimizeOriginalGraphForQuantize(ComputeGraphPtr &compute_graph) { if (compute_graph == nullptr) { GELOGE(GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL, "[OptimizeOriginalGraph]: compute_graph is nullptr."); return GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL; } - Status ret = SUCCESS; std::shared_ptr instance_ptr = ge::GELib::GetInstance(); if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { GELOGE(GE_CLI_GE_NOT_INITIALIZED, "OptimizeOriginalGraph failed."); @@ -200,25 +198,19 @@ Status GraphOptimize::NewOptimizeOriginalGraph(ComputeGraphPtr &compute_graph) { } auto graph_optimizer = instance_ptr->OpsKernelManagerObj().GetAllGraphOptimizerObjsByPriority(); - GELOGI("optimize by opskernel in original graph optimize phase. 
num of graph_optimizer is %lu.", + GELOGI("optimize by opskernel in original graph optimize quantize phase. num of graph_optimizer is %zu.", graph_optimizer.size()); + Status ret = SUCCESS; string exclude_core_Type = (core_type_ == kVectorCore) ? kAicoreEngine : kVectorEngine; - GELOGD("[OptimizeOriginalGraph]: engine type will exclude: %s", exclude_core_Type.c_str()); + GELOGD("[OptimizeOriginalGraphForQuantize]: engine type will exclude: %s", exclude_core_Type.c_str()); if (graph_optimizer.size() != 0) { for (auto iter = graph_optimizer.begin(); iter != graph_optimizer.end(); ++iter) { - if (iter->first == exclude_core_Type) { + if (iter->first == exclude_core_Type || iter->second == nullptr) { continue; } - ret = (iter->second)->OptimizeOriginalGraph(*compute_graph); - if (ret != SUCCESS) { - GELOGE(ret, "[OptimizeOriginalGraph]: graph optimize failed, ret:%d", ret); - return ret; - } - - // call fe - ret = (iter->second)->OptimizeOriginalGraphJudgeInsert(*compute_graph); + ret = iter->second->OptimizeGraphPrepare(*compute_graph); if (ret != SUCCESS) { - GELOGE(ret, "[OptimizeOriginalGraphForInsert]: graph optimize failed, ret:%d", ret); + GELOGE(ret, "[OptimizeOriginalGraphForQuantize]: graph optimize failed, ret:%u", ret); return ret; } } @@ -226,32 +218,33 @@ Status GraphOptimize::NewOptimizeOriginalGraph(ComputeGraphPtr &compute_graph) { return ret; } -Status GraphOptimize::OptimizeOriginalGraphForQuantize(ComputeGraphPtr &compute_graph) { +Status GraphOptimize::OptimizeGraphBeforeBuildForRts(ComputeGraphPtr &compute_graph) { if (compute_graph == nullptr) { - GELOGE(GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL, "[OptimizeOriginalGraph]: compute_graph is nullptr."); + GELOGE(GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL, "[OptimizeGraphBeforeBuildForRts]: compute_graph is nullptr."); return GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL; } std::shared_ptr instance_ptr = ge::GELib::GetInstance(); if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { - 
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "OptimizeOriginalGraph failed."); + GELOGE(GE_CLI_GE_NOT_INITIALIZED, "OptimizeGraphBeforeBuildForRts failed."); return GE_CLI_GE_NOT_INITIALIZED; } auto graph_optimizer = instance_ptr->OpsKernelManagerObj().GetAllGraphOptimizerObjsByPriority(); - GELOGI("optimize by opskernel in original graph optimize quantize phase. num of graph_optimizer is %zu.", + GELOGI("optimize by opskernel in graph optimize before build phase. num of graph_optimizer is %zu.", graph_optimizer.size()); Status ret = SUCCESS; string exclude_core_Type = (core_type_ == kVectorCore) ? kAicoreEngine : kVectorEngine; - GELOGD("[OptimizeOriginalGraphForQuantize]: engine type will exclude: %s", exclude_core_Type.c_str()); + GELOGI("[OptimizeGraphBeforeBuildForRts]: engine type will exclude: %s, core_type_: %s", exclude_core_Type.c_str(), + core_type_.c_str()); if (graph_optimizer.size() != 0) { for (auto iter = graph_optimizer.begin(); iter != graph_optimizer.end(); ++iter) { if (iter->first == exclude_core_Type || iter->second == nullptr) { continue; } - ret = iter->second->OptimizeGraphPrepare(*compute_graph); + ret = iter->second->OptimizeGraphBeforeBuild(*compute_graph); if (ret != SUCCESS) { - GELOGE(ret, "[OptimizeOriginalGraphForQuantize]: graph optimize failed, ret:%u", ret); + GELOGE(ret, "[OptimizeGraphBeforeBuildForRts]: graph optimize failed, ret:%u", ret); return ret; } } diff --git a/src/ge/graph/optimize/graph_optimize.h b/src/ge/graph/optimize/graph_optimize.h index f3eb2009..0bbeb0f7 100644 --- a/src/ge/graph/optimize/graph_optimize.h +++ b/src/ge/graph/optimize/graph_optimize.h @@ -49,12 +49,12 @@ class GraphOptimize { Status OptimizeOriginalGraphJudgeInsert(ComputeGraphPtr &compute_graph); - // new original graph optimize - Status NewOptimizeOriginalGraph(ComputeGraphPtr &compute_graph); - // for fe prepare optimize in quantize scene Status OptimizeOriginalGraphForQuantize(ComputeGraphPtr &compute_graph); + // for rts optimize before build to add 
attr and insert memcpy op + Status OptimizeGraphBeforeBuildForRts(ComputeGraphPtr &compute_graph); + // set options Status SetOptions(const GraphManagerOptions &options); @@ -62,8 +62,6 @@ class GraphOptimize { return summary_output_indexes_; } // lint !e1073 - void ClearSummaryOutputIndexes() { summary_output_indexes_.clear(); } - // handle summary node before preRun graph Status HandleSummaryOp(ComputeGraphPtr &compute_graph); diff --git a/src/ge/graph/optimize/mem_rw_conflict_optimize.cc b/src/ge/graph/optimize/mem_rw_conflict_optimize.cc index f75565ba..3ecc201a 100644 --- a/src/ge/graph/optimize/mem_rw_conflict_optimize.cc +++ b/src/ge/graph/optimize/mem_rw_conflict_optimize.cc @@ -41,13 +41,14 @@ enum class OutputRWType { kWriteable, // ref output. Like Assign/ApplyMomentum kInvalidRWType }; + // input and output rw_type of one node. key is anchor_idx, value is rw_type struct NodeInputOutputRWType { map input_rw_type_map; map output_rw_type_map; }; // input and output rw_type of node in current graph -map node_rwtype_map_; +thread_local map node_rwtype_map_; /// /// @brief Convert input rw_type enum to string. For log print. 
@@ -110,21 +111,22 @@ OutputRWType GetSingleNodeOutputRWTypeByIndex(const Node &node, uint32_t index) /// @param rw_type_set /// @return /// -InputRWType GetInputRwTypeInConflict(std::set rw_type_set) { +InputRWType GetInputRwTypeInConflict(const std::set &rw_type_set) { // for input rw type calc int total_rw_type = 0; - for (auto rw : rw_type_set) { + for (const auto rw : rw_type_set) { total_rw_type += rw; } + switch (total_rw_type) { case 0: - return InputRWType::kReadOnly; + return InputRWType::kReadOnly; // all input rw type is readonly case 2: - return InputRWType::kScopeWriteable; + return InputRWType::kScopeWriteable; // readonly 2 scope_writeable case 3: - return InputRWType::kWriteable; + return InputRWType::kWriteable; // all input rw type is writeable or readonly 2 writeable case 5: - return InputRWType::kInvalidRWType; + return InputRWType::kInvalidRWType; // writeable 2 scope_writeable default: return InputRWType::kInvalidRWType; } @@ -145,12 +147,12 @@ NodePtr CreateIdentityAfterSrcNode(const Node &src_node, int out_anchor_idx) { } auto data_desc = src_node.GetOpDesc()->GetOutputDesc(out_anchor_idx); // 2. 
add input_desc & output_desc for new identity - Status ret = identity_opdesc->AddInputDesc(data_desc); + Status ret = identity_opdesc->AddInputDesc("x", data_desc); if (ret != SUCCESS) { GELOGE(ret, "Add Input desc failed for new identity %s.", identity_name.c_str()); return nullptr; } - ret = identity_opdesc->AddOutputDesc(data_desc); + ret = identity_opdesc->AddOutputDesc("y", data_desc); if (ret != SUCCESS) { GELOGE(ret, "Add Output desc failed for new Identity %s.", identity_name.c_str()); return nullptr; @@ -227,7 +229,7 @@ InputRWType GetSingleNodeInputRWTypeByIndex(const Node &node, uint32_t index) { return InputRWType::kWriteable; } } - // check if it is ref switch todo + // check if it is ref switch std::string type; if ((node.GetType() == FRAMEWORK_OP_TYPE) && (AttrUtils::GetStr(op_desc, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type)) && (type == REFSWITCH) && (index == 0)) { @@ -283,12 +285,66 @@ InputRWType GetInputRWTypeByIndex(const Node &node, uint32_t index) { return GetInputRwTypeInConflict(anchor_rw_type_set); } } +Status MarkRWTypeForSubgraph(const ComputeGraphPtr &sub_graph) { + for (const auto &node : sub_graph->GetDirectNode()) { + GE_CHECK_NOTNULL(node); + GE_CHECK_NOTNULL(node->GetOpDesc()); + std::set anchor_rw_type_set; + if (node->GetType() == DATA) { + // calc all input_rw_type of peer output , as input_rw_type of DATA. Index 0 is valid. 
+ auto anchor_2_node_vec = NodeUtils::GetOutDataNodesWithAnchorByIndex(*node, 0); + for (const auto anchor_2_node_pair : anchor_2_node_vec) { + auto input_rw_type = GetInputRWTypeByIndex(*anchor_2_node_pair.second, anchor_2_node_pair.first->GetIdx()); + GELOGD("Input rw type of Node %s %dth input anchor is %s", anchor_2_node_pair.second->GetName().c_str(), + anchor_2_node_pair.first->GetIdx(), InputRWTypeToSerialString(input_rw_type).c_str()); + anchor_rw_type_set.emplace(static_cast(input_rw_type)); + } + auto anchor_rw_type = GetInputRwTypeInConflict(anchor_rw_type_set); + GELOGD("Input rw type of Node %s is %s", node->GetName().c_str(), + InputRWTypeToSerialString(anchor_rw_type).c_str()); + map input_rw_type_map{std::make_pair(0, anchor_rw_type)}; + NodeInputOutputRWType data_rw_type{input_rw_type_map}; + node_rwtype_map_.emplace(std::make_pair(node->GetName(), data_rw_type)); + } + if (node->GetType() == NETOUTPUT) { + // calc all output_rw_type of peer input , as output_rw_type of DATA + map output_rw_type_map; + for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { + GE_CHECK_NOTNULL(in_data_anchor); + auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(pre_out_anchor); + auto pre_node = pre_out_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(pre_node); + + auto pre_output_rw_type = GetOutputRWTypeByIndex(*pre_node, pre_out_anchor->GetIdx()); + GELOGD("Output rw type of Node %s %dth output anchor is %s", pre_node->GetName().c_str(), + pre_out_anchor->GetIdx(), OutputRWTypeToSerialString(pre_output_rw_type).c_str()); + if (pre_output_rw_type == OutputRWType::kWriteable) { + // insert identity + auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); + GE_CHECK_NOTNULL(identity_node); + auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); + if (ret != SUCCESS) { + GELOGE(ret, "Fail to insert identity"); + return ret; + } + GELOGI("InsertNode %s between 
%s and %s successfully.", identity_node->GetName().c_str(), + pre_node->GetName().c_str(), node->GetName().c_str()); + } + output_rw_type_map.emplace(std::make_pair(in_data_anchor->GetIdx(), OutputRWType::kSoftRead)); + } + NodeInputOutputRWType output_rw_type{{}, output_rw_type_map}; + node_rwtype_map_.emplace(std::make_pair(node->GetName(), output_rw_type)); + } + } + return SUCCESS; +} /// /// @brief Reverse traversal all subgraph and mark rw_type for Data/Netoutput. /// @param sub_graph_vecgs /// -Status MarkRWTypeForSubgraph(vector> sub_graph_vec) { +Status MarkRWTypeForAllSubgraph(const vector &sub_graph_vec) { for (auto iter = sub_graph_vec.rbegin(); iter != sub_graph_vec.rend(); ++iter) { auto parent_node = (*iter)->GetParentNode(); if (parent_node == nullptr) { @@ -298,61 +354,9 @@ Status MarkRWTypeForSubgraph(vector> sub_graph_vec if (parent_node->GetType() == WHILE) { continue; } - for (const auto &node : (*iter)->GetDirectNode()) { - GE_CHECK_NOTNULL(node); - GE_CHECK_NOTNULL(node->GetOpDesc()); - if (node->GetType() == DATA) { - // calc all input_rw_type of peer output , as input_rw_type of DATA. Index 0 is valid. 
- auto out_data_anchor = node->GetOutDataAnchor(0); - GE_CHECK_NOTNULL(out_data_anchor); - std::set anchor_rw_type_set; - for (const auto peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { - GE_CHECK_NOTNULL(peer_in_anchor); - auto peer_in_node = peer_in_anchor->GetOwnerNode(); - GE_CHECK_NOTNULL(peer_in_node); - auto input_rw_type = GetInputRWTypeByIndex(*peer_in_node, peer_in_anchor->GetIdx()); - GELOGD("Input rw type of Node %s %dth input anchor is %s", peer_in_node->GetName().c_str(), - peer_in_anchor->GetIdx(), InputRWTypeToSerialString(input_rw_type).c_str()); - anchor_rw_type_set.emplace(static_cast(input_rw_type)); - } - auto anchor_rw_type = GetInputRwTypeInConflict(anchor_rw_type_set); - GELOGD("Input rw type of Node %s is %s", node->GetName().c_str(), - InputRWTypeToSerialString(anchor_rw_type).c_str()); - map input_rw_type_map{std::make_pair(0, anchor_rw_type)}; - NodeInputOutputRWType data_rw_type{input_rw_type_map}; - node_rwtype_map_.emplace(std::make_pair(node->GetName(), data_rw_type)); - } - - if (node->GetType() == NETOUTPUT) { - // calc all output_rw_type of peer input , as output_rw_type of DATA - map output_rw_type_map; - for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { - GE_CHECK_NOTNULL(in_data_anchor); - auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(pre_out_anchor); - auto pre_node = pre_out_anchor->GetOwnerNode(); - GE_CHECK_NOTNULL(pre_node); - - auto pre_output_rw_type = GetOutputRWTypeByIndex(*pre_node, pre_out_anchor->GetIdx()); - GELOGD("Output rw type of Node %s %dth output anchor is %s", pre_node->GetName().c_str(), - pre_out_anchor->GetIdx(), OutputRWTypeToSerialString(pre_output_rw_type).c_str()); - if (pre_output_rw_type == OutputRWType::kWriteable) { - // insert identity - auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); - GE_CHECK_NOTNULL(identity_node); - auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, 
in_data_anchor, identity_node); - if (ret != SUCCESS) { - GELOGE(ret, "Fail to insert identity"); - return ret; - } - GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), - pre_node->GetName().c_str(), node->GetName().c_str()); - } - output_rw_type_map.emplace(std::make_pair(in_data_anchor->GetIdx(), OutputRWType::kSoftRead)); - } - NodeInputOutputRWType output_rw_type{{}, output_rw_type_map}; - node_rwtype_map_.emplace(std::make_pair(node->GetName(), output_rw_type)); - } + auto ret = MarkRWTypeForSubgraph(*iter); + if (ret != SUCCESS) { + return ret; } } return SUCCESS; @@ -367,8 +371,8 @@ Status MarkRWTypeForSubgraph(vector> sub_graph_vec /// @param node /// @return is_near_subgraph /// -bool CheckIdentityIsNearSubgraph(const NodePtr &node) { - for (const auto &in_node : node->GetInDataNodes()) { +bool CheckIdentityIsNearSubgraph(const Node &node) { + for (const auto &in_node : node.GetInDataNodes()) { auto in_node_opdesc = in_node->GetOpDesc(); if (in_node_opdesc == nullptr) { continue; @@ -383,7 +387,7 @@ bool CheckIdentityIsNearSubgraph(const NodePtr &node) { } } - for (const auto &out_node : node->GetOutDataNodes()) { + for (const auto &out_node : node.GetOutDataNodes()) { auto out_node_opdesc = out_node->GetOpDesc(); if (out_node_opdesc == nullptr) { continue; @@ -419,74 +423,113 @@ ConflictResult GetConflictResultBetweenNode(const OutputRWType output_rw_type, c /// @return /// Status RemoveNoUseIdentity(const NodePtr &node) { - if (node->GetInDataNodes().empty()) { - return SUCCESS; - } - if (node->GetOutDataNodesSize() > 1) { + if (node->GetInDataNodes().empty() || node->GetOutDataNodesSize() > 1) { return SUCCESS; } if (node->GetOutDataNodesSize() == 1 && node->GetOutDataNodes().at(0)->GetType() == STREAMMERGE) { return SUCCESS; } - if (CheckIdentityIsNearSubgraph(node)) { + if (CheckIdentityIsNearSubgraph(*node)) { return SUCCESS; } - auto out_data_anchor = node->GetOutDataAnchor(kIdentityAnchorIndex); - 
GE_CHECK_NOTNULL(out_data_anchor); GE_CHECK_NOTNULL(node->GetInDataAnchor(kIdentityAnchorIndex)); auto pre_out_anchor = node->GetInDataAnchor(kIdentityAnchorIndex)->GetPeerOutAnchor(); GE_CHECK_NOTNULL(pre_out_anchor); auto pre_node = pre_out_anchor->GetOwnerNode(); auto pre_output_rw_type = GetOutputRWTypeByIndex(*pre_node, pre_out_anchor->GetIdx()); + auto anchor_2_outnode_vec = NodeUtils::GetOutDataNodesWithAnchorByIndex(*node, kIdentityAnchorIndex); ConflictResult conflict_result = WRONG_GRAPH; - if (!out_data_anchor->GetPeerInDataAnchors().empty()) { - auto peer_in_data_anchor = out_data_anchor->GetPeerInDataAnchors().at(0); - GE_CHECK_NOTNULL(peer_in_data_anchor); - auto peer_node = peer_in_data_anchor->GetOwnerNode(); - auto peer_input_rw_type = GetInputRWTypeByIndex(*peer_node, peer_in_data_anchor->GetIdx()); + if (!anchor_2_outnode_vec.empty()) { + auto anchor_2_outnode = anchor_2_outnode_vec.at(0); + auto peer_input_rw_type = GetInputRWTypeByIndex(*anchor_2_outnode.second, anchor_2_outnode.first->GetIdx()); GELOGD("Pre Node %s %dth output rw type is %s, peer node %s %dth input rw type is %s.", pre_node->GetName().c_str(), pre_out_anchor->GetIdx(), OutputRWTypeToSerialString(pre_output_rw_type).c_str(), - peer_node->GetName().c_str(), peer_in_data_anchor->GetIdx(), + anchor_2_outnode.second->GetName().c_str(), anchor_2_outnode.first->GetIdx(), InputRWTypeToSerialString(peer_input_rw_type).c_str()); conflict_result = GetConflictResultBetweenNode(pre_output_rw_type, peer_input_rw_type); } else { // identity node has no out data node, it can be removed conflict_result = DO_NOTHING; } + if (conflict_result != DO_NOTHING) { + return SUCCESS; + } - switch (conflict_result) { - case DO_NOTHING: { - GELOGI("No need insert Identity. 
Node %s need to remove.", node->GetName().c_str()); - auto ret = GraphUtils::IsolateNode(node, {0}); - if (ret != SUCCESS) { - GELOGE(ret, "Fail to isolate node %s.", node->GetName().c_str()); - return ret; - } - ret = GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node); - if (ret != SUCCESS) { - GELOGE(ret, "Fail to isolate node %s.", node->GetName().c_str()); - return ret; - } - GELOGI("Pre node is %s and %dth output rw type is %s. Isolate and remove Identity node %s.", - pre_node->GetName().c_str(), pre_out_anchor->GetIdx(), - OutputRWTypeToSerialString(pre_output_rw_type).c_str(), node->GetName().c_str()); - return SUCCESS; + GELOGI("No need insert Identity. Node %s need to remove.", node->GetName().c_str()); + auto ret = GraphUtils::IsolateNode(node, {0}); + if (ret != SUCCESS) { + GELOGE(ret, "Fail to isolate node %s.", node->GetName().c_str()); + return ret; + } + ret = GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node); + if (ret != SUCCESS) { + GELOGE(ret, "Fail to isolate node %s.", node->GetName().c_str()); + return ret; + } + GELOGI("Pre node is %s and %dth output rw type is %s. Isolate and remove Identity node %s.", + pre_node->GetName().c_str(), pre_out_anchor->GetIdx(), OutputRWTypeToSerialString(pre_output_rw_type).c_str(), + node->GetName().c_str()); + return SUCCESS; +} + +Status SplitIdentityAlongAnchor(const OutDataAnchorPtr &out_data_anchor, const InDataAnchorPtr &peer_in_data_anchor, + const OutDataAnchorPtr &pre_out_data_anchor, NodePtr &pre_node) { + // 1.check peer in node RW type. 
+ GE_CHECK_NOTNULL(peer_in_data_anchor); + auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(peer_in_data_node); + auto input_rw_type = GetInputRWTypeByIndex(*peer_in_data_node, peer_in_data_anchor->GetIdx()); + auto ret = out_data_anchor->Unlink(peer_in_data_anchor); + auto old_identity = out_data_anchor->GetOwnerNode(); + if (ret != SUCCESS) { + GELOGE(ret, "Failed to unlink from %s %dth out to %s.", old_identity->GetName().c_str(), out_data_anchor->GetIdx(), + peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); + return ret; + } + if (input_rw_type == InputRWType::kScopeWriteable || input_rw_type == InputRWType::kWriteable) { + auto new_identity = CreateIdentityAfterSrcNode(*pre_node, pre_out_data_anchor->GetIdx()); + GE_CHECK_NOTNULL(new_identity); + if (GraphUtils::AddEdge(pre_out_data_anchor, new_identity->GetInDataAnchor(kIdentityAnchorIndex)) != SUCCESS && + GraphUtils::AddEdge(new_identity->GetOutDataAnchor(kIdentityAnchorIndex), peer_in_data_anchor) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Failed to insert Identity between node %s and %s", + pre_out_data_anchor->GetOwnerNode()->GetName().c_str(), + peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); + return INTERNAL_ERROR; } - default: - return SUCCESS; + + // 2. copy in-control-edge from dst to Identity + if (GraphUtils::CopyInCtrlEdges(peer_in_data_node, new_identity) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Failed to copy in_control edges from node %s to %s", peer_in_data_node->GetName().c_str(), + new_identity->GetName().c_str()); + return INTERNAL_ERROR; + } + GELOGI("Node %s intput rw type is %s. 
Insert Identity between %s and %s.", peer_in_data_node->GetName().c_str(), + InputRWTypeToSerialString(input_rw_type).c_str(), pre_out_data_anchor->GetOwnerNode()->GetName().c_str(), + peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); + } else { + // copy control edge to pre and peer node + if (GraphUtils::CopyInCtrlEdges(old_identity, peer_in_data_node) != SUCCESS || + GraphUtils::CopyOutCtrlEdges(old_identity, pre_node) != SUCCESS) { + GELOGW("Fail to copy control edge from node %s.", old_identity->GetName().c_str()); + return FAILED; + } + // link identity pre node to next node directly + if (GraphUtils::AddEdge(pre_out_data_anchor, peer_in_data_anchor) != SUCCESS) { + GELOGW("Fail to link data edge from node %s to %s.", pre_out_data_anchor->GetOwnerNode()->GetName().c_str(), + peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); + return FAILED; + } + GELOGI("Node %s input rw type is %s, link data edge from Identity input node %s to out node %s directly.", + peer_in_data_node->GetName().c_str(), InputRWTypeToSerialString(input_rw_type).c_str(), + pre_node->GetName().c_str(), peer_in_data_node->GetName().c_str()); } return SUCCESS; } Status SplitIdentity(const NodePtr &node) { GE_CHECK_NOTNULL(node); - auto op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - if (op_desc->GetType() != IDENTITY) { - return SUCCESS; - } auto out_data_anchor = node->GetOutDataAnchor(kIdentityAnchorIndex); GE_CHECK_NOTNULL(out_data_anchor); if (out_data_anchor->GetPeerInDataNodesSize() <= 1) { @@ -498,56 +541,17 @@ Status SplitIdentity(const NodePtr &node) { GE_CHECK_NOTNULL(pre_out_data_anchor); auto pre_node = pre_out_data_anchor->GetOwnerNode(); GE_CHECK_NOTNULL(pre_node); + Status ret = SUCCESS; for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { - // 1.check peer in node RW type. 
- GE_CHECK_NOTNULL(peer_in_data_anchor); - auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode(); - GE_CHECK_NOTNULL(peer_in_data_node); - auto input_rw_type = GetInputRWTypeByIndex(*peer_in_data_node, peer_in_data_anchor->GetIdx()); - auto ret = out_data_anchor->Unlink(peer_in_data_anchor); + ret = SplitIdentityAlongAnchor(out_data_anchor, peer_in_data_anchor, pre_out_data_anchor, pre_node); if (ret != SUCCESS) { - GELOGE(ret, "Failed to unlink from %s %dth out to %s.", node->GetName().c_str(), out_data_anchor->GetIdx(), - peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); + GELOGE(ret, "Split identity node along anchor failed."); return ret; } - if (input_rw_type == InputRWType::kScopeWriteable || input_rw_type == InputRWType::kWriteable) { - auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_data_anchor->GetIdx()); - GE_CHECK_NOTNULL(identity_node); - ret = GraphUtils::AddEdge(pre_out_data_anchor, identity_node->GetInDataAnchor(kIdentityAnchorIndex)); - if (ret != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Failed to insert Identity between node %s and %s", - pre_out_data_anchor->GetOwnerNode()->GetName().c_str(), - peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); - return INTERNAL_ERROR; - } - ret = GraphUtils::AddEdge(identity_node->GetOutDataAnchor(kIdentityAnchorIndex), peer_in_data_anchor); - if (ret != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Failed to insert Identity between node %s and %s", - pre_out_data_anchor->GetOwnerNode()->GetName().c_str(), - peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); - return INTERNAL_ERROR; - } - // 2. copy in-control-edge from dst to Identity - GraphUtils::CopyInCtrlEdges(peer_in_data_node, identity_node); - GELOGI("Node %s intput rw type is %s. 
Insert Identity between %s and %s.", peer_in_data_node->GetName().c_str(), - InputRWTypeToSerialString(input_rw_type).c_str(), pre_out_data_anchor->GetOwnerNode()->GetName().c_str(), - peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); - } else { - // link identity pre node to next node directly - // todo control edge - if (GraphUtils::AddEdge(pre_out_data_anchor, peer_in_data_anchor) != SUCCESS) { - GELOGW("Fail to link data edge from node %s to %s.", pre_out_data_anchor->GetOwnerNode()->GetName().c_str(), - peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); - return FAILED; - } - GELOGI("Node %s intput rw type is %s, link data edge from Identity input node %s to out node %s directly.", - peer_in_data_node->GetName().c_str(), InputRWTypeToSerialString(input_rw_type).c_str(), - pre_node->GetName().c_str(), peer_in_data_node->GetName().c_str()); - } } // 2.isolate Identity node with no data output if (node->GetOutDataNodesSize() == 0) { - auto ret = GraphUtils::IsolateNode(node, {}); + ret = GraphUtils::IsolateNode(node, {}); if (ret != SUCCESS) { GELOGE(FAILED, "IsolateAndDelete identity node %s.", node->GetName().c_str()); return FAILED; @@ -616,7 +620,7 @@ Status GraphOptimize::CheckRWConflict(ComputeGraphPtr &compute_graph, bool &has_ return SUCCESS; } // 1.loop all subgraph, mark rw type from inside to outside - Status ret = MarkRWTypeForSubgraph(sub_graph_vec); + Status ret = MarkRWTypeForAllSubgraph(sub_graph_vec); if (ret != SUCCESS) { GELOGE(ret, "Fail to mark rw type for subgraph."); return ret; @@ -671,7 +675,7 @@ Status GraphOptimize::HandleMemoryRWConflict(ComputeGraphPtr &compute_graph) { } GE_DUMP(compute_graph, "BeforeHandleMemConflict"); // 1.loop all subgraph, mark rw type from inside to outside - Status ret = MarkRWTypeForSubgraph(sub_graph_vec); + Status ret = MarkRWTypeForAllSubgraph(sub_graph_vec); if (ret != SUCCESS) { GELOGE(ret, "Fail to mark rw type for subgraph."); return ret; diff --git 
a/src/ge/graph/partition/dynamic_shape_partition.cc b/src/ge/graph/partition/dynamic_shape_partition.cc index 903159b9..e5a33b37 100644 --- a/src/ge/graph/partition/dynamic_shape_partition.cc +++ b/src/ge/graph/partition/dynamic_shape_partition.cc @@ -139,6 +139,7 @@ std::string DynamicShapePartitioner::DebugString() const { size_t known = 0; size_t data = 0; size_t netoutput = 0; + size_t is_inputnode = 0; std::stringstream ss; ss << "All unknown shape nodes:" << std::endl; for (const auto &node : unknown_shape_nodes_) { @@ -153,10 +154,12 @@ std::string DynamicShapePartitioner::DebugString() const { data++; } else if (cluster->IsNetOutput()) { netoutput++; + } else if (cluster->IsInputNode()) { + is_inputnode++; } } ss << "All clusters:" << unique_clusters_.size() << ", data:" << data << ", known:" << known - << ", unknown:" << unknown << ", netoutput:" << netoutput << std::endl; + << ", unknown:" << unknown << ", netoutput:" << netoutput << ", is_inputnode:" << is_inputnode << std::endl; for (const auto &cluster : unique_clusters_) { ss << " " << cluster->DebugString() << std::endl; } @@ -195,8 +198,11 @@ Status DynamicShapePartitioner::InitClusters() { size_t rank = 0; for (const auto &node : graph->GetDirectNode()) { Cluster::Type type = Cluster::DATA; + bool is_input = ((node->GetType() == CONSTANT) || (node->GetType() == CONSTANTOP)) && node->GetInNodes().empty(); if (node->GetType() == DATA) { type = Cluster::DATA; + } else if (is_input) { + type = Cluster::INPUT_NODE; } else if (node->GetType() == NETOUTPUT) { type = Cluster::NETOUTPUT; } else if (unknown_shape_nodes_.count(node) > 0) { @@ -246,7 +252,7 @@ Status DynamicShapePartitioner::TopologicalSortClusters() { auto cluster = ready_clusters.front(); ready_clusters.pop(); cluster->UpdateRank(rank++); - if (cluster->IsKnownShape()) { + if (cluster->IsKnownShape() || cluster->IsInputNode()) { ordered_cluster_.push_back(cluster); } for (const auto &out_cluster : cluster->Outputs()) { @@ -278,7 +284,7 @@ 
static std::string ToString(const std::vector &clusters) { } } // namespace -Status DynamicShapePartitioner::MergeClusters() { +void DynamicShapePartitioner::MergeClustersUnknownShape() { // Merge unknown shape clusters for (const auto &cluster : ordered_cluster_) { for (const auto &in_cluster : cluster->Inputs()) { @@ -295,8 +301,9 @@ Status DynamicShapePartitioner::MergeClusters() { } } } +} - REQUIRE_SUCCESS(TopologicalSortClusters(), "Failed topological sort clusters after merge unknown shape clusters."); +void DynamicShapePartitioner::MergeClustersKnownShape() { // Merge known shape clusters for (const auto &cluster : ordered_cluster_) { if (cluster->IsRefVariable() && cluster->Inputs().size() == 1) { @@ -318,6 +325,32 @@ Status DynamicShapePartitioner::MergeClusters() { } } } +} + +void DynamicShapePartitioner::MergeClustersInputData() { + // Merge input clusters + std::shared_ptr cluster_pre = nullptr; + for (const auto &cluster : ordered_cluster_) { + if (!cluster->IsInputNode()) { + continue; + } + if (cluster_pre != nullptr) { + cluster_pre->Merge(cluster); + } else { + cluster_pre = cluster; + } + GELOGD("Success merge input node cluster from %lu to %lu.", cluster->Id(), cluster->Id()); + for (const auto &node : cluster->Nodes()) { + node_2_cluster_[node] = cluster_pre; + } + } +} + +Status DynamicShapePartitioner::MergeClusters() { + MergeClustersUnknownShape(); + REQUIRE_SUCCESS(TopologicalSortClusters(), "Failed topological sort clusters after merge unknown shape clusters."); + MergeClustersKnownShape(); + MergeClustersInputData(); return SUCCESS; } @@ -448,6 +481,9 @@ std::string Cluster::DebugString() const { case DATA: ss << "DATA"; break; + case INPUT_NODE: + ss << "INPUT_NODE"; + break; case NETOUTPUT: ss << "NETOUTPUT"; break; @@ -483,6 +519,7 @@ bool Cluster::IsData() const { return type_ == DATA; }; bool Cluster::IsKnownShape() const { return type_ == KNOWN_SHAPE; }; bool Cluster::IsUnknownShape() const { return type_ == UNKNOWN_SHAPE; }; bool 
Cluster::IsNetOutput() const { return type_ == NETOUTPUT; }; +bool Cluster::IsInputNode() const { return type_ == INPUT_NODE; }; bool Cluster::IsRefVariable() const { if ((nodes_.size() == 1) && ((nodes_[0]->GetType() == VARIABLE) || (nodes_[0]->GetType() == VARIABLEV2))) { std::string ref_variable_name; @@ -492,27 +529,37 @@ bool Cluster::IsRefVariable() const { return false; } void Cluster::AddInput(ClusterPtr in) { - in_clusters_.insert(in); - in->out_clusters_.insert(shared_from_this()); + if (std::find(in_clusters_.begin(), in_clusters_.end(), in) != in_clusters_.end()) return; + in_clusters_.insert(in_clusters_.end(), in); + if (std::find(in->out_clusters_.begin(), in->out_clusters_.end(), shared_from_this()) != in->out_clusters_.end()) + return; + in->out_clusters_.insert(in->out_clusters_.end(), shared_from_this()); }; void Cluster::RemoveInput(ClusterPtr in) { - in_clusters_.erase(in); - in->out_clusters_.erase(shared_from_this()); + in_clusters_.erase(std::remove(in_clusters_.begin(), in_clusters_.end(), in), in_clusters_.end()); + in->out_clusters_.erase(std::remove(in->out_clusters_.begin(), in->out_clusters_.end(), shared_from_this()), + in->out_clusters_.end()); }; void Cluster::AddOutput(ClusterPtr out) { - out_clusters_.insert(out); - out->in_clusters_.insert(shared_from_this()); + if (std::find(out_clusters_.begin(), out_clusters_.end(), out) != out_clusters_.end()) return; + out_clusters_.insert(out_clusters_.end(), out); + if (std::find(out->in_clusters_.begin(), out->in_clusters_.end(), shared_from_this()) != out->in_clusters_.end()) + return; + out->in_clusters_.insert(out->in_clusters_.end(), shared_from_this()); }; void Cluster::RemoveOutput(ClusterPtr out) { - out_clusters_.erase(out); - out->in_clusters_.erase(shared_from_this()); + out_clusters_.erase(std::remove(out_clusters_.begin(), out_clusters_.end(), out), out_clusters_.end()); + out->in_clusters_.erase(std::remove(out->in_clusters_.begin(), out->in_clusters_.end(), 
shared_from_this()), + out->in_clusters_.end()); }; void Cluster::Merge(ClusterPtr other) { nodes_.insert(nodes_.end(), other->nodes_.begin(), other->nodes_.end()); - other->in_clusters_.erase(shared_from_this()); - other->out_clusters_.erase(shared_from_this()); - in_clusters_.erase(other); - out_clusters_.erase(other); + other->in_clusters_.erase(std::remove(other->in_clusters_.begin(), other->in_clusters_.end(), shared_from_this()), + other->in_clusters_.end()); + other->out_clusters_.erase(std::remove(other->out_clusters_.begin(), other->out_clusters_.end(), shared_from_this()), + other->out_clusters_.end()); + in_clusters_.erase(std::remove(in_clusters_.begin(), in_clusters_.end(), other), in_clusters_.end()); + out_clusters_.erase(std::remove(out_clusters_.begin(), out_clusters_.end(), other), out_clusters_.end()); auto in_clusters = other->in_clusters_; for (const auto &cluster : in_clusters) { cluster->RemoveOutput(other); @@ -555,7 +602,8 @@ std::vector Cluster::MergeAllPathFrom(ClusterPtr other) { std::unordered_set backward_reached_clusters; std::vector path_clusters; - if (other->out_clusters_.count(shared_from_this()) == 0) { + if (std::find(other->out_clusters_.begin(), other->out_clusters_.end(), shared_from_this()) == + other->out_clusters_.end()) { return path_clusters; } path_clusters.push_back(other); @@ -590,8 +638,8 @@ std::vector Cluster::MergeAllPathFrom(ClusterPtr other) { } return path_clusters; } -std::unordered_set Cluster::Inputs() const { return in_clusters_; }; -std::unordered_set Cluster::Outputs() const { return out_clusters_; }; +std::vector Cluster::Inputs() const { return in_clusters_; }; +std::vector Cluster::Outputs() const { return out_clusters_; }; std::vector Cluster::Nodes() const { return nodes_; }; void Cluster::AddFrameInput(InDataAnchorPtr anchor) { @@ -617,7 +665,7 @@ InControlAnchorPtr Cluster::GetFrameInControlAnchor() { return partition_node_-> OutControlAnchorPtr Cluster::GetFrameOutControlAnchor() { return 
partition_node_->GetOutControlAnchor(); }; Status Cluster::BuildFrame() { - if (IsUnknownShape() || IsKnownShape()) { + if (IsUnknownShape() || IsKnownShape() || IsInputNode()) { return BuildPartitionFrame(); } else { auto node = nodes_.front(); @@ -653,8 +701,10 @@ Status Cluster::BuildFrame() { Status Cluster::BuildPartitionFrame() { auto graph = partitioner_->root_graph_; bool is_unknown_shape = IsUnknownShape(); - std::string sub_graph_name = - graph->GetName() + "_sub_" + std::to_string(unique_id_) + (is_unknown_shape ? "_unknow" : "_know"); + bool is_input = IsInputNode(); + string known_name = (is_unknown_shape ? "_unknow" : "_know"); + string sub_graph_name_patten = (is_input ? "_input" : known_name); + std::string sub_graph_name = graph->GetName() + "_sub_" + std::to_string(unique_id_) + sub_graph_name_patten; subgraph_ = MakeShared(sub_graph_name); REQUIRE_NOT_NULL(subgraph_, "Failed new memory for subgraph."); auto partition_op = MakeShared("PartitionedCall_" + std::to_string(unique_id_++), "PartitionedCall"); diff --git a/src/ge/graph/partition/dynamic_shape_partition.h b/src/ge/graph/partition/dynamic_shape_partition.h index ba349b1c..b851a084 100644 --- a/src/ge/graph/partition/dynamic_shape_partition.h +++ b/src/ge/graph/partition/dynamic_shape_partition.h @@ -32,7 +32,7 @@ class DynamicShapePartitioner { // DATA:DATA, UNKNOWN_SHAPE:unknowshape, KNOWN_SHAPE:knowshape, NETOUTPUT:NETOUTPUT. 
class Cluster : public std::enable_shared_from_this { public: - enum Type { DATA, NETOUTPUT, KNOWN_SHAPE, UNKNOWN_SHAPE }; + enum Type { DATA, INPUT_NODE, NETOUTPUT, KNOWN_SHAPE, UNKNOWN_SHAPE }; Cluster(size_t rank, Type type, NodePtr node, DynamicShapePartitioner *partitioner) : id_(rank), min_(rank), max_(rank), type_(type), partitioner_(partitioner) { nodes_.push_back(node); @@ -46,8 +46,9 @@ class DynamicShapePartitioner { bool IsKnownShape() const; bool IsUnknownShape() const; bool IsNetOutput() const; - std::unordered_set> Inputs() const; - std::unordered_set> Outputs() const; + std::vector> Inputs() const; + std::vector> Outputs() const; + bool IsInputNode() const; std::vector Nodes() const; bool IsRefVariable() const; // Cluster modify functions @@ -86,8 +87,8 @@ class DynamicShapePartitioner { size_t min_; // maximum topological order size_t max_; // minimum topological order Type type_; - std::unordered_set> in_clusters_; - std::unordered_set> out_clusters_; + std::vector> in_clusters_; + std::vector> out_clusters_; std::vector nodes_; // Fileds for build partitoned call and subgraph DynamicShapePartitioner *partitioner_; // Not owned, the partitioner this cluster belongs to @@ -121,7 +122,14 @@ class DynamicShapePartitioner { // merge all the clusters in the path(s) between the two clusters // 2) Iterate through the KNOWN_SHAPE clusters, if the input is KNOWN_SHAPE, and // and there's only one path between the two clusters , merge the two clusters + // 3) Iterate through the INPUT_DATA clusters, merge all INPUT_DATA Status MergeClusters(); + // Merge clusters step1 + void MergeClustersUnknownShape(); + // Merge clusters step2 + void MergeClustersKnownShape(); + // Merge clusters step3 + void MergeClustersInputData(); // Topological sort clusters after merge unknow shape clusters. 
Status TopologicalSortClusters(); // Deduplicate merged clusters diff --git a/src/ge/graph/passes/assign_pass.cc b/src/ge/graph/passes/assign_pass.cc new file mode 100644 index 00000000..fe287f90 --- /dev/null +++ b/src/ge/graph/passes/assign_pass.cc @@ -0,0 +1,133 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/assign_pass.h" + +#include "framework/common/debug/ge_log.h" +#include "framework/common/debug/log.h" +#include "graph/utils/graph_utils.h" +#include "graph/debug/ge_attr_define.h" + +namespace { +const uint32_t kValidInputNodeOutputNum = 1; +const int32_t kAssignRefInputIndex = 0; +const int32_t kAssignValueInputIndex = 1; +} // namespace + +namespace ge { +Status AssignPass::Run(NodePtr &node) { + GELOGD("AssignPass running"); + if (node->GetType() != ASSIGN) { + GELOGD("No need run AssignPass on [%s, %s].", node->GetName().c_str(), node->GetType().c_str()); + return SUCCESS; + } + + const auto &ref_in_anchor = node->GetInDataAnchor(kAssignRefInputIndex); + const auto &value_in_anchor = node->GetInDataAnchor(kAssignValueInputIndex); + if ((ref_in_anchor == nullptr) || (value_in_anchor == nullptr)) { + GELOGE(FAILED, "In data anchor is null, node:%s", node->GetName().c_str()); + return FAILED; + } + const auto &ref_peer_anchor = ref_in_anchor->GetPeerOutAnchor(); + const auto &value_peer_anchor = value_in_anchor->GetPeerOutAnchor(); + if ((ref_peer_anchor == 
nullptr) || (value_peer_anchor == nullptr)) { + GELOGE(FAILED, "Peer data anchor is null, node:%s", node->GetName().c_str()); + return FAILED; + } + + if (IsCondMatch(node, ref_peer_anchor, value_peer_anchor)) { + /// + /// variable not-const not-const + /// \ / | + /// \ / | + /// Assign ----> variable + /// | | + /// | | + /// node node + /// + GELOGI("Optimization for assign_node %s start", node->GetName().c_str()); + if (IsolateAndDeleteNode(node, {kAssignRefInputIndex}) != SUCCESS) { + GELOGE(FAILED, "Isolate and delete assign_node %s failed.", node->GetName().c_str()); + return FAILED; + } + AddNodeDeleted(node); + + const auto &ref_input = ref_peer_anchor->GetOwnerNode()->GetOpDesc(); + const auto &value_input = value_peer_anchor->GetOwnerNode()->GetOpDesc(); + if ((ref_input == nullptr) || (value_input == nullptr)) { + GELOGE(FAILED, "value input is null"); + return FAILED; + } + if (!AttrUtils::SetStr(value_input->MutableOutputDesc(value_peer_anchor->GetIdx()), ASSIGN_VAR_NAME, + ref_input->GetName())) { + GELOGE(FAILED, "Set attr ASSIGN_VAR_NAME failed."); + return FAILED; + } + + // variable has and only has one input + if (ref_input->UpdateInputDesc(0, value_input->GetOutputDesc(value_peer_anchor->GetIdx())) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Update input_desc for variable %s failed.", ref_input->GetName().c_str()); + return FAILED; + } + if (GraphUtils::AddEdge(value_peer_anchor, ref_peer_anchor->GetOwnerNode()->GetInDataAnchor(0)) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Add data edge %s->%s failed", value_input->GetName().c_str(), ref_input->GetName().c_str()); + return FAILED; + } + } + + GELOGD("AssignPass success"); + return SUCCESS; +} + +/// +/// @brief Check if need optimize for assign_node +/// @param [in] assign_node +/// @param [in] peer_data_anchor for ref_input of assign_node +/// @param [in] peer_data_anchor for value_input of assign_node +/// @return Status +/// +bool AssignPass::IsCondMatch(const NodePtr &node, const OutDataAnchorPtr 
&ref_peer_anchor, + const OutDataAnchorPtr &value_peer_anchor) { + GELOGD("Check if assign_node %s match optimization condition, ref_input: %s, value_input: %s", + node->GetName().c_str(), ref_peer_anchor->GetOwnerNode()->GetName().c_str(), + value_peer_anchor->GetOwnerNode()->GetName().c_str()); + + const std::string &value_type = value_peer_anchor->GetOwnerNode()->GetType(); + if ((value_type == CONSTANTOP) || (value_type == CONSTANT)) { + GELOGD("value input is const"); + return false; + } + + const std::string &ref_type = ref_peer_anchor->GetOwnerNode()->GetType(); + if ((ref_type != VARIABLE) && (ref_type != VARIABLEV2)) { + GELOGD("ref input is not var"); + return false; + } + if (!ref_peer_anchor->GetOwnerNode()->GetInDataNodes().empty()) { + GELOGD("ref input has data input"); + return false; + } + + if ((ref_peer_anchor->GetPeerInDataNodesSize() != kValidInputNodeOutputNum) || + (value_peer_anchor->GetPeerInDataNodesSize() != kValidInputNodeOutputNum)) { + GELOGD("ref / value input has other output(s)"); + return false; + } + + GELOGD("Optimization condition matches, assign_node: %s", node->GetName().c_str()); + return true; +} +} // namespace ge diff --git a/src/ge/graph/passes/switch_fusion_pass.h b/src/ge/graph/passes/assign_pass.h similarity index 53% rename from src/ge/graph/passes/switch_fusion_pass.h rename to src/ge/graph/passes/assign_pass.h index 10ba5dad..d7dc5138 100644 --- a/src/ge/graph/passes/switch_fusion_pass.h +++ b/src/ge/graph/passes/assign_pass.h @@ -14,24 +14,26 @@ * limitations under the License. 
*/ -#ifndef GE_GRAPH_PASSES_SWITCH_FUSION_PASS_H_ -#define GE_GRAPH_PASSES_SWITCH_FUSION_PASS_H_ +#ifndef GE_GRAPH_PASSES_ASSIGN_PASS_H_ +#define GE_GRAPH_PASSES_ASSIGN_PASS_H_ -#include #include "graph/passes/base_pass.h" + namespace ge { -class SwitchFusionPass : public BaseNodePass { +class AssignPass : public BaseNodePass { public: Status Run(NodePtr &node) override; private: - Status FuseSwitchGroup(); - Status RemoveSwitchBetweenTwoNode(const int switch_out_anchor_idx, const NodePtr &switch_node); - Status FuseSwitchNodesToOne(NodePtr &remain_switch, const std::set switch_nodes_set); - const string GetFusionRoadId(const string branch_id, const NodePtr &switch_node); - NodePtr InsertIdentityNode(const NodePtr &remain_switch, const int out_data_anchor_idx); - map> switch_group_map_; + /// + /// @brief Check if need optimize for assign_node + /// @param [in] assign_node + /// @param [in] peer_data_anchor for ref_input of assign_node + /// @param [in] peer_data_anchor for value_input of assign_node + /// @return Status + /// + static bool IsCondMatch(const NodePtr &node, const OutDataAnchorPtr &ref_peer_anchor, + const OutDataAnchorPtr &value_peer_anchor); }; } // namespace ge - -#endif // GE_GRAPH_PASSES_SWITCH_FUSION_PASS_H_ +#endif // GE_GRAPH_PASSES_ASSIGN_PASS_H_ diff --git a/src/ge/graph/passes/attach_stream_label_pass.cc b/src/ge/graph/passes/attach_stream_label_pass.cc index 9962821b..b8065325 100644 --- a/src/ge/graph/passes/attach_stream_label_pass.cc +++ b/src/ge/graph/passes/attach_stream_label_pass.cc @@ -183,14 +183,12 @@ Status AttachStreamLabelPass::UpdateEnterNode() { std::unordered_map> enter_active_map; for (const auto &enter_node : enter_nodes_) { for (const auto &out_ctrl_node : enter_node->GetOutControlNodes()) { - if (out_ctrl_node->GetType() != STREAMACTIVE) { - continue; - } - auto iter = enter_active_map.find(out_ctrl_node); - if (iter == enter_active_map.end()) { - enter_active_map[out_ctrl_node] = {enter_node}; - } else { - 
iter->second.emplace_back(enter_node); + if (out_ctrl_node->GetType() == STREAMACTIVE) { + if (enter_active_map.find(out_ctrl_node) == enter_active_map.end()) { + enter_active_map[out_ctrl_node] = {enter_node}; + } else { + enter_active_map[out_ctrl_node].emplace_back(enter_node); + } } } } @@ -204,17 +202,29 @@ Status AttachStreamLabelPass::UpdateEnterNode() { NodePtr active_node = pair.first; GE_CHECK_NOTNULL(active_node); std::vector active_label_list; - if (!AttrUtils::GetListStr(active_node->GetOpDesc(), ATTR_NAME_ACTIVE_LABEL_LIST, active_label_list) || - (active_label_list.size() != 1) || active_label_list[0].empty()) { + bool get_attr = AttrUtils::GetListStr(active_node->GetOpDesc(), ATTR_NAME_ACTIVE_LABEL_LIST, active_label_list) && + (active_label_list.size() == 1) && !active_label_list[0].empty(); + if (!get_attr) { GELOGE(INTERNAL_ERROR, "Get attr ATTR_NAME_ACTIVE_LABEL_LIST failed, node: %s.", active_node->GetName().c_str()); return INTERNAL_ERROR; } std::stack enter_nodes; + std::string batch_label; for (const auto &enter_node : pair.second) { enter_nodes.emplace(enter_node); + std::string tmp_label; + (void)AttrUtils::GetStr(enter_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, tmp_label); + if (!tmp_label.empty()) { + if (batch_label.empty()) { + batch_label = tmp_label; + } else if (batch_label != tmp_label) { + GELOGE(FAILED, "multi batch_label exist, label1=%s, label2=%s.", batch_label.c_str(), tmp_label.c_str()); + return FAILED; + } + } } - if (UpdateLoopBranch(enter_nodes, active_label_list[0]) != SUCCESS) { + if (UpdateLoopBranch(enter_nodes, active_label_list[0], batch_label) != SUCCESS) { GELOGE(FAILED, "Update stream_label for loop_branch failed."); return FAILED; } @@ -234,28 +244,16 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector &enter_no GE_CHECK_NOTNULL(active_node); (void)AttrUtils::GetStr(active_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label); - bool same_flag = true; - for (const auto &enter_node : enter_nodes) 
{ - std::string tmp_label; - (void)AttrUtils::GetStr(enter_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, tmp_label); - if (tmp_label.empty() || (stream_label == tmp_label)) { - continue; - } - same_flag = false; - break; - } - if (stream_label.empty()) { - if (same_flag) { - stream_label = active_node->GetName(); - } else { - GELOGW("stream_label of enter_active is empty while stream_label of some enter_node is not."); - return SUCCESS; - } + GELOGW("stream_label of enter_active & enter_nodes is empty."); + return SUCCESS; } for (const auto &enter_node : enter_nodes) { - GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "Set stream label failed."); + GE_CHECK_NOTNULL(enter_node->GetOpDesc()); + if (enter_node->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL)) { + GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "Set stream label failed."); + } } GE_CHK_STATUS_RET(SetStreamLabel(active_node, stream_label), "Set stream label failed."); return SUCCESS; @@ -265,10 +263,11 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector &enter_no /// @brief Update stream_label for loop_branch /// @param [in] enter_nodes /// @param [in] stream_label +/// @param [in] batch_label /// @return Status /// -Status AttachStreamLabelPass::UpdateLoopBranch(const std::stack &enter_nodes, - const std::string &stream_label) { +Status AttachStreamLabelPass::UpdateLoopBranch(const std::stack &enter_nodes, const std::string &stream_label, + const std::string &batch_label) { std::stack nodes(enter_nodes); NodePtr cur_node = nullptr; while (!nodes.empty()) { @@ -277,8 +276,16 @@ Status AttachStreamLabelPass::UpdateLoopBranch(const std::stack &enter_ for (const NodePtr &out_node : cur_node->GetOutAllNodes()) { OpDescPtr out_desc = out_node->GetOpDesc(); GE_CHECK_NOTNULL(out_desc); + std::string tmp_label; + (void)AttrUtils::GetStr(out_desc, ATTR_NAME_BATCH_LABEL, tmp_label); + if (!tmp_label.empty() && (tmp_label != batch_label)) { + continue; + } std::string out_type = 
out_desc->GetType(); - if (out_desc->HasAttr(ATTR_NAME_STREAM_LABEL) || (out_type == ENTER) || (out_type == REFENTER)) { + bool need_skip = + out_desc->HasAttr(ATTR_NAME_STREAM_LABEL) || (out_type == ENTER) || (out_type == REFENTER) || + (((cur_node->GetType() == ENTER) || (cur_node->GetType() == REFENTER)) && (out_type == STREAMACTIVE)); + if (need_skip) { continue; } GELOGD("Attach label %s to node: %s.", stream_label.c_str(), out_node->GetName().c_str()); diff --git a/src/ge/graph/passes/attach_stream_label_pass.h b/src/ge/graph/passes/attach_stream_label_pass.h index fc6abd30..5820480d 100644 --- a/src/ge/graph/passes/attach_stream_label_pass.h +++ b/src/ge/graph/passes/attach_stream_label_pass.h @@ -62,9 +62,11 @@ class AttachStreamLabelPass : public GraphPass { /// @brief Update stream_label for loop_branch /// @param [in] enter_nodes /// @param [in] stream_label + /// @param [in] batch_label /// @return Status /// - static Status UpdateLoopBranch(const std::stack &enter_nodes, const std::string &stream_label); + static Status UpdateLoopBranch(const std::stack &enter_nodes, const std::string &stream_label, + const std::string &batch_label); /// /// @brief Update stream_label start with enter nodes diff --git a/src/ge/graph/passes/cast_remove_pass.cc b/src/ge/graph/passes/cast_remove_pass.cc index f7ff941c..ab4f2098 100644 --- a/src/ge/graph/passes/cast_remove_pass.cc +++ b/src/ge/graph/passes/cast_remove_pass.cc @@ -50,6 +50,10 @@ Status CastRemovePass::Run(NodePtr &node) { return PARAM_INVALID; } + if (!CheckPrecisionLoss(nodes_to_fuse)) { + return SUCCESS; + } + DataType type = DT_UNDEFINED; if (!HasSameDataType(op_desc, end_op_desc, type)) { return SUCCESS; @@ -60,6 +64,15 @@ Status CastRemovePass::Run(NodePtr &node) { return SUCCESS; } +bool CastRemovePass::CheckPrecisionLoss(const std::vector &nodes_to_fuse) { + for (const NodePtr &node : nodes_to_fuse) { + if (!TransOpUtil::CheckPrecisionLoss(node)) { + return false; + } + } + return true; +} + bool 
CastRemovePass::HasSameDataType(OpDescPtr &begin_op_desc, OpDescPtr &end_op_desc, DataType &type) const { if (begin_op_desc->GetName() == end_op_desc->GetName()) { return false; @@ -81,7 +94,7 @@ bool CastRemovePass::HasSameDataType(OpDescPtr &begin_op_desc, OpDescPtr &end_op // op1->TransData->TransposeD->TransData->op2 Status CastRemovePass::RemoveCast(DataType &type, std::vector &nodes_to_fuse) { string cast_name; - for (NodePtr node : nodes_to_fuse) { + for (NodePtr &node : nodes_to_fuse) { if (node->GetType() == CAST) { GELOGI("CastRemovePass, remove Cast %s.", node->GetName().c_str()); cast_name = node->GetName(); diff --git a/src/ge/graph/passes/cast_remove_pass.h b/src/ge/graph/passes/cast_remove_pass.h index e889781f..67fa697e 100644 --- a/src/ge/graph/passes/cast_remove_pass.h +++ b/src/ge/graph/passes/cast_remove_pass.h @@ -26,6 +26,7 @@ class CastRemovePass : public BaseNodePass { Status Run(NodePtr &node) override; private: + bool CheckPrecisionLoss(const std::vector &nodes_to_fuse); bool HasSameDataType(OpDescPtr &begin_op_desc, OpDescPtr &end_op_desc, DataType &type) const; Status RemoveCast(DataType &type, std::vector &nodes_to_fuse); NodePtr GetTheEndNode(NodePtr begin_node, std::vector &nodes_to_fuse); diff --git a/src/ge/graph/passes/cond_pass.cc b/src/ge/graph/passes/cond_pass.cc index 2f3f9333..03ca9009 100644 --- a/src/ge/graph/passes/cond_pass.cc +++ b/src/ge/graph/passes/cond_pass.cc @@ -47,6 +47,14 @@ Status CondPass::Run(NodePtr &node) { GE_CHECK_NOTNULL(op_desc); GELOGI("Handle cond for node %s.", op_desc->GetName().c_str()); GeTensorDesc cond_tensor = op_desc->GetInputDesc(cond_in_anchor->GetIdx()); + if (cond_tensor.MutableShape().GetDim(0) == UNKNOWN_DIM_NUM) { + GELOGI("Output tensor rank of Cond is unknown."); + if (cond_tensor.GetDataType() == DT_STRING) { + GE_CHK_STATUS_RET(HandleStringCond(graph, cond_out_anchor, cond_in_anchor), "HandleStringCond for %s failed.", + op_desc->GetName().c_str()) + } + return SUCCESS; + } if 
(!cond_tensor.GetShape().IsScalar()) { GE_CHK_STATUS_RET(HandleNonScalarCond(graph, cond_out_anchor, cond_in_anchor), "HandleNonScalarCond for %s failed.", op_desc->GetName().c_str()) @@ -255,8 +263,8 @@ Status CondPass::InsertNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr GeTensorDesc out_tensor = in_anchor->GetOwnerNode()->GetOpDesc()->GetInputDesc(out_anchor->GetIdx()); out_tensor.SetDataType(DT_INT32); out_tensor.SetOriginDataType(DT_INT32); - out_tensor.SetShape(GeShape()); - out_tensor.SetOriginShape(GeShape()); + out_tensor.SetShape(in_tensor.GetShape()); + out_tensor.SetOriginShape(in_tensor.GetOriginShape()); OpDescBuilder op_desc_builder(out_anchor->GetOwnerNode()->GetName() + "_" + type, type); OpDescPtr op_desc = op_desc_builder.AddInput("x", in_tensor).AddOutput("y", out_tensor).Build(); diff --git a/src/ge/graph/passes/flow_ctrl_pass.cc b/src/ge/graph/passes/flow_ctrl_pass.cc index 03f8d5a6..430cf86d 100644 --- a/src/ge/graph/passes/flow_ctrl_pass.cc +++ b/src/ge/graph/passes/flow_ctrl_pass.cc @@ -32,11 +32,6 @@ namespace ge { Status FlowCtrlPass::Run(ComputeGraphPtr compute_graph) { GE_CHECK_NOTNULL(compute_graph); - if (AddGlobalStepVariableNode(compute_graph) != SUCCESS) { - GELOGE(FAILED, "Add global step variable node fail."); - return FAILED; - } - if (!PassUtils::IsNeedTrainIteFlowCtrl(compute_graph)) { GELOGI("No need FlowCtrl for graph %u", compute_graph->GetGraphID()); return NOT_CHANGED; @@ -193,8 +188,9 @@ Status FlowCtrlPass::AddGlobalStepVariableNode(ComputeGraphPtr &compute_graph) { GELOGD("Node type %s can't be found in graph %u", NETOUTPUT, compute_graph->GetGraphID()); return SUCCESS; } - - if (compute_graph->GetParentGraph() != nullptr) { // Global step just add to main graph. 
+ // Global step just add to main graph's netoutput node.And the main graph must be known shape + if ((compute_graph->GetParentGraph() != nullptr) || + ((compute_graph->GetParentGraph() == nullptr) && (GraphUtils::IsUnknownShapeGraph(compute_graph)))) { GELOGD("Subgraph %s no need global step variable.", compute_graph->GetName().c_str()); return SUCCESS; } diff --git a/src/ge/graph/passes/folding_pass.cc b/src/ge/graph/passes/folding_pass.cc index 44dbc182..8281db5d 100644 --- a/src/ge/graph/passes/folding_pass.cc +++ b/src/ge/graph/passes/folding_pass.cc @@ -174,7 +174,7 @@ Status FoldingPass::DealWithInNodes(NodePtr &node) { if (in_node == nullptr) { continue; } - if ((in_node->GetType() == SWITCH) || (in_node->GetType() == REFSWITCH)) { + if ((in_node->GetType() == SWITCH) || (in_node->GetType() == REFSWITCH) || (in_node->GetType() == SWITCHN)) { GELOGI("The in_node name is %s, and node type is %s.", in_node->GetName().c_str(), in_node->GetType().c_str()); auto ret = in_node_anchor->Unlink(in_data_anchor); if (ret != SUCCESS) { diff --git a/src/ge/graph/passes/folding_pass.h b/src/ge/graph/passes/folding_pass.h index 9c8d3a7e..0ffd2eb2 100644 --- a/src/ge/graph/passes/folding_pass.h +++ b/src/ge/graph/passes/folding_pass.h @@ -33,9 +33,11 @@ bool IsNoNeedConstantFolding(const NodePtr &node); using IndexsToAnchors = std::map>; class FoldingPass : public BaseNodePass { + public: + static Status RunOpKernel(NodePtr &node, const vector &inputs, vector &outputs); + protected: Status Folding(NodePtr &node, vector &outputs); - static Status RunOpKernel(NodePtr &node, const vector &inputs, vector &outputs); private: Status AddConstNode(NodePtr &node, IndexsToAnchors indexes_to_anchors, std::vector &v_weight); diff --git a/src/ge/graph/passes/global_step_insert_pass.cc b/src/ge/graph/passes/global_step_insert_pass.cc new file mode 100644 index 00000000..460f6ad6 --- /dev/null +++ b/src/ge/graph/passes/global_step_insert_pass.cc @@ -0,0 +1,99 @@ +/** + * Copyright 
2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/global_step_insert_pass.h" + +#include +#include +#include + +#include "framework/common/debug/ge_log.h" +#include "framework/common/util.h" +#include "graph/debug/ge_attr_define.h" +#include "common/ge/ge_util.h" +#include "graph/manager/graph_var_manager.h" +#include "graph/passes/pass_utils.h" + +namespace ge { +NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph, const string &node_type, const string &node_name, + const std::vector &input_list, + const std::vector &output_list) { + OpDescPtr op_desc = MakeShared(node_name, node_type); + GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(FAILED, "Make OpDesc failed"); return nullptr); + + for (auto &input_desc : input_list) { + graphStatus graph_status = op_desc->AddInputDesc(input_desc); + if (graph_status != GRAPH_SUCCESS) { + GELOGE(FAILED, "Add node:%s intput desc failed, error=%u.", node_name.c_str(), graph_status); + return nullptr; + } + } + + for (auto &output_desc : output_list) { + graphStatus graph_status = op_desc->AddOutputDesc(output_desc); + if (graph_status != GRAPH_SUCCESS) { + GELOGE(FAILED, "Add node:%s output desc failed, error=%u.", node_name.c_str(), graph_status); + return nullptr; + } + } + + GE_IF_BOOL_EXEC(compute_graph == nullptr, GELOGE(FAILED, "compute_graph is nullptr"); return nullptr); + NodePtr node = compute_graph->AddNode(op_desc); + 
GE_IF_BOOL_EXEC(node == nullptr, + GELOGE(FAILED, "add node failed, name:%s, type:%s.", node_name.c_str(), node_type.c_str()); + return nullptr); + + GELOGI("Insert op success, name:%s, type:%s.", node_name.c_str(), node_type.c_str()); + return node; +} + +Status GlobalStepInsertPass::Run(ComputeGraphPtr compute_graph) { + NodePtr output_node = compute_graph->FindFirstNodeMatchType(NETOUTPUT); + if (output_node == nullptr) { + GELOGD("Node type %s can't be found in graph %u", NETOUTPUT, compute_graph->GetGraphID()); + return SUCCESS; + } + + if (compute_graph->GetParentGraph() != nullptr) { + GELOGD("Subgraph %s no need global step variable.", compute_graph->GetName().c_str()); + return SUCCESS; + } + + NodePtr exist_node = compute_graph->FindNode(NODE_NAME_GLOBAL_STEP); + if (exist_node != nullptr) { + GELOGD("Node %s already exist, no need add.", NODE_NAME_GLOBAL_STEP.c_str()); + return SUCCESS; + } + // set global step tensor desc + GeTensorDesc tensor_desc(GeShape({1}), FORMAT_ND, DT_UINT64); + std::vector input_desc_list = {}; + std::vector output_desc_list = {tensor_desc}; + NodePtr global_step = InsertOp(compute_graph, VARIABLE, NODE_NAME_GLOBAL_STEP, input_desc_list, output_desc_list); + if (global_step == nullptr) { + GELOGE(FAILED, "Add global_step node failed, global_step is null."); + return FAILED; + } + + // add ctrl edges + graphStatus add_ret = GraphUtils::AddEdge(global_step->GetOutControlAnchor(), output_node->GetInControlAnchor()); + if (add_ret != GRAPH_SUCCESS) { + GELOGE(FAILED, "Add global_step to netoutput edge failed, add_ret=%u.", add_ret); + return FAILED; + } + GELOGD("Add global_step to netoutput edge in graph %u success", compute_graph->GetGraphID()); + return SUCCESS; +} +} // namespace ge \ No newline at end of file diff --git a/src/ge/graph/passes/global_step_insert_pass.h b/src/ge/graph/passes/global_step_insert_pass.h new file mode 100644 index 00000000..46bc85d6 --- /dev/null +++ b/src/ge/graph/passes/global_step_insert_pass.h @@ 
-0,0 +1,57 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_GLOBAL_STEP_INSERT_PASS_H_ +#define GE_GRAPH_PASSES_GLOBAL_STEP_INSERT_PASS_H_ + +#include +#include + +#include "common/ge_inner_error_codes.h" +#include "inc/graph_pass.h" + +namespace ge { +/// +/// Add global step op to the computeGraph when needed. +/// [Notice]: this pass must work before graph partitioner start work +/// in order to make the global step variable place in known subgraph +/// +class GlobalStepInsertPass : public GraphPass { + public: + /// + /// @param compute_graph graph + /// @return SUCCESS: do success + /// NOT_CHANGED : do nothing + /// Other: failed + /// + Status Run(ComputeGraphPtr compute_graph) override; + + private: + /// + /// Universal insert node to graph. + /// @param compute_graph graph + /// @param node_type inserted node type + /// @param node_name inserted node name + /// @param input_list input desc list + /// @param output_list output desc list + /// @return the inserted node. if insert failed return nullptr. 
+ /// + NodePtr InsertOp(ComputeGraphPtr &compute_graph, const string &node_type, const string &node_name, + const std::vector &input_list, const std::vector &output_list); +}; +} // namespace ge + +#endif // GE_GRAPH_PASSES_GLOBAL_STEP_INSERT_PASS_H_ \ No newline at end of file diff --git a/src/ge/graph/passes/hccl_memcpy_pass.cc b/src/ge/graph/passes/hccl_memcpy_pass.cc index a9b3484b..b8787476 100644 --- a/src/ge/graph/passes/hccl_memcpy_pass.cc +++ b/src/ge/graph/passes/hccl_memcpy_pass.cc @@ -89,7 +89,7 @@ Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { /// @param [in] ge::OutDataAnchorPtr in_node /// @return ge::NodePtr /// -NodePtr HcclMemcpyPass::CreateMemcpyNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor) { +NodePtr HcclMemcpyPass::CreateIdentityNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor) { GE_IF_BOOL_EXEC(graph == nullptr, return nullptr); NodePtr pre_node = out_data_anchor->GetOwnerNode(); OpDescPtr pre_op_desc = pre_node->GetOpDesc(); @@ -98,30 +98,32 @@ NodePtr HcclMemcpyPass::CreateMemcpyNode(const ComputeGraphPtr &graph, const Out return nullptr; } - std::string node_name = pre_node->GetName() + "_" + MEMCPYASYNC; + std::string node_name = pre_node->GetName() + "_" + IDENTITY; node_name = CheckDuplicateName(node_name); - OpDescPtr op_desc = MakeShared(node_name.c_str(), MEMCPYASYNC); + OpDescPtr op_desc = MakeShared(node_name.c_str(), IDENTITY); if (op_desc == nullptr) { - GELOGE(INTERNAL_ERROR, "Create MemcpyAsync op: MakeShared op_desc fail."); + GELOGE(INTERNAL_ERROR, "Create identity op: MakeShared op_desc fail."); return nullptr; } - GELOGI("Create MemcpyAsync op:%s.", op_desc->GetName().c_str()); + GELOGI("Create identity op:%s.", op_desc->GetName().c_str()); - graphStatus ret = op_desc->AddInputDesc(pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); + graphStatus ret = op_desc->AddInputDesc("x", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); if (ret != 
GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Create MemcpyAsync op: add input desc fail."); + GELOGE(INTERNAL_ERROR, "Create identity op: add input desc fail."); return nullptr; } - ret = op_desc->AddOutputDesc(pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); + ret = op_desc->AddOutputDesc("y", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); if (ret != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Create MemcpyAsync op: add output desc fail."); + GELOGE(INTERNAL_ERROR, "Create identity op: add output desc fail."); return nullptr; } + // because history reason ,this pass can not do work after constant fold so mark it + (void)AttrUtils::SetBool(op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); NodePtr memcpy_node = graph->AddNode(op_desc); if (memcpy_node == nullptr) { - GELOGE(INTERNAL_ERROR, "Insert MemcpyAsync node fail."); + GELOGE(INTERNAL_ERROR, "Insert identity node fail."); return nullptr; } @@ -155,7 +157,7 @@ std::string HcclMemcpyPass::CheckDuplicateName(const std::string &node_name) { Status HcclMemcpyPass::ModifyEdgeConnection(const ComputeGraphPtr &graph, const OutDataAnchorPtr &src_out_anchor, const InDataAnchorPtr &hccl_in_anchor) { GELOGI("The op %s need insert memcpy async op.", src_out_anchor->GetOwnerNode()->GetName().c_str()); - NodePtr memcpy_node = CreateMemcpyNode(graph, src_out_anchor); + NodePtr memcpy_node = CreateIdentityNode(graph, src_out_anchor); GE_CHECK_NOTNULL(memcpy_node); Status ret1 = src_out_anchor->Unlink(hccl_in_anchor); diff --git a/src/ge/graph/passes/hccl_memcpy_pass.h b/src/ge/graph/passes/hccl_memcpy_pass.h index 13863bd6..44b40241 100644 --- a/src/ge/graph/passes/hccl_memcpy_pass.h +++ b/src/ge/graph/passes/hccl_memcpy_pass.h @@ -30,7 +30,7 @@ class HcclMemcpyPass : public GraphPass { Status ClearStatus() override; private: - NodePtr CreateMemcpyNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor); + NodePtr CreateIdentityNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr 
&out_data_anchor); std::string CheckDuplicateName(const std::string &node_name); diff --git a/src/ge/graph/passes/identity_pass.cc b/src/ge/graph/passes/identity_pass.cc index 1f4725bf..57b7c46d 100644 --- a/src/ge/graph/passes/identity_pass.cc +++ b/src/ge/graph/passes/identity_pass.cc @@ -21,6 +21,8 @@ #include "framework/common/debug/ge_log.h" #include "graph/common/omg_util.h" #include "graph/utils/node_utils.h" +#include "graph/utils/attr_utils.h" +#include "graph/debug/ge_attr_define.h" namespace ge { namespace { @@ -33,8 +35,14 @@ namespace { /// or as input of Netoutput of subgraph /// or as input of one node with subgraph /// or as output of one node with subgraph +/// 3. identity with attr no_need_constant_folding should not be deleted too Status CheckIdentityUsable(const NodePtr &node, bool &usable) { std::string node_type; + if (node->GetOpDesc()->HasAttr(ge::ATTR_NO_NEED_CONSTANT_FOLDING)) { + usable = true; + return SUCCESS; + } + for (auto &in_node : node->GetInDataNodes()) { auto in_node_opdesc = in_node->GetOpDesc(); GE_CHECK_NOTNULL(in_node_opdesc); @@ -47,7 +55,8 @@ Status CheckIdentityUsable(const NodePtr &node, bool &usable) { GE_CHK_STATUS_RET(GetOriginalType(in_node, node_type), "Failed to get node type from node %s", node->GetName().c_str()); - if ((node_type != SWITCH) && (node_type != REFSWITCH)) { + bool need_skip = (node_type != SWITCH) && (node_type != REFSWITCH) && (node_type != SWITCHN); + if (need_skip) { GELOGD("skip identity %s connected to switch", node->GetName().c_str()); break; } diff --git a/src/ge/graph/passes/infershape_pass.cc b/src/ge/graph/passes/infershape_pass.cc index 8b44d31b..7ed1ea8c 100644 --- a/src/ge/graph/passes/infershape_pass.cc +++ b/src/ge/graph/passes/infershape_pass.cc @@ -24,8 +24,6 @@ namespace ge { Status InferShapePass::Run(NodePtr &node) { auto ret = ShapeRefiner::InferShapeAndType(node, !OptionExists(kOptimizeAfterSubGraph)); if (ret != GRAPH_SUCCESS) { - 
ErrorManager::GetInstance().ATCReportErrMessage("E35003", {"opname", "err_msg"}, - {node->GetName(), "check your model!"}); GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "infershape failed. node: %s", node->GetName().c_str()); return GE_GRAPH_INFERSHAPE_FAILED; } diff --git a/src/ge/graph/passes/memcpy_addr_async_pass.cc b/src/ge/graph/passes/memcpy_addr_async_pass.cc index 7cbacc23..3af40888 100644 --- a/src/ge/graph/passes/memcpy_addr_async_pass.cc +++ b/src/ge/graph/passes/memcpy_addr_async_pass.cc @@ -34,6 +34,17 @@ Status MemcpyAddrAsyncPass::Run(ComputeGraphPtr graph) { return ret; } } + // handle data->netoutput, const->netoutput in root graph, use mem_addr_async to improve performance + if (op_desc->GetType() == NETOUTPUT) { + // check this netoutput is on root graph + if (node->GetOwnerComputeGraph()->GetParentNode() == nullptr) { + Status ret = InsertMemAddrAsyncNodeBeforeNetoutput(node->GetOwnerComputeGraph(), node); + if (ret != SUCCESS) { + GELOGE(ret, "AddMemcpyAddrAsyncNode failed."); + return ret; + } + } + } } return SUCCESS; } @@ -179,7 +190,7 @@ void MemcpyAddrAsyncPass::FindUserData(const NodePtr &parent_node, uint32_t &par NodePtr MemcpyAddrAsyncPass::CreateMemcpyAddrAsyncNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor, const NodePtr &out_of_user_data) { - GELOGI("Start CreateMemcpyAddrAsyncNode."); + GELOGD("Start CreateMemcpyAddrAsyncNode."); OpDescPtr pre_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc(); GE_CHK_BOOL_EXEC(pre_op_desc != nullptr, return nullptr, "Op_desc of pre node is invalid."); std::string node_name = pre_op_desc->GetName() + "_" + MEMCPYADDRASYNC; @@ -242,4 +253,27 @@ Status MemcpyAddrAsyncPass::InsertMemcpyAddrAsyncNode(const OutDataAnchorPtr &ou return SUCCESS; } +Status MemcpyAddrAsyncPass::InsertMemAddrAsyncNodeBeforeNetoutput(const ComputeGraphPtr &graph, const NodePtr &node) { + GELOGI("Start AddMemcpyAddrAsyncNode for %s.", node->GetName().c_str()); + for (const auto &in_data_anchor : 
node->GetAllInDataAnchors()) { + auto in_node = NodeUtils::GetInDataNodeByIndex(*node, in_data_anchor->GetIdx()); + GE_CHECK_NOTNULL(in_node); + auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + if ((in_node->GetType() != CONSTANT) && (in_node->GetType() != CONSTANTOP) && (in_node->GetType() != DATA)) { + continue; + } + GELOGI("Need to insert MemcpyAddrAsync before netoutput on parent graph."); + NodePtr memcpy_addr_async_node = CreateMemcpyAddrAsyncNode(graph, peer_out_anchor, in_node); + GE_IF_BOOL_EXEC(memcpy_addr_async_node == nullptr, GELOGE(INTERNAL_ERROR, "CreateMemcpyAddrAsyncNode failed."); + return INTERNAL_ERROR); + + Status ret = InsertMemcpyAddrAsyncNode(peer_out_anchor, in_data_anchor, memcpy_addr_async_node); + GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(ret, "InsertMemcpyAddrAsyncNode failed."); return ret); + GELOGI("Insert mem_addr_async node %s success between %s and %s.", memcpy_addr_async_node->GetName().c_str(), + in_node->GetName().c_str(), node->GetName().c_str()); + NodeUtils::UpdateIsInputConst(memcpy_addr_async_node); + } + NodeUtils::UpdateIsInputConst(node); + return SUCCESS; +} } // namespace ge diff --git a/src/ge/graph/passes/memcpy_addr_async_pass.h b/src/ge/graph/passes/memcpy_addr_async_pass.h index 9d99e505..1f184bd5 100644 --- a/src/ge/graph/passes/memcpy_addr_async_pass.h +++ b/src/ge/graph/passes/memcpy_addr_async_pass.h @@ -35,6 +35,7 @@ class MemcpyAddrAsyncPass : public GraphPass { const NodePtr &out_of_user_data); Status InsertMemcpyAddrAsyncNode(const OutDataAnchorPtr &out_anchor, const InDataAnchorPtr &in_anchor, const NodePtr &node); + Status InsertMemAddrAsyncNodeBeforeNetoutput(const ComputeGraphPtr &graph, const NodePtr &node); NodePtr user_data_; NodePtr out_of_user_data_; diff --git a/src/ge/graph/passes/merge_to_stream_merge_pass.cc b/src/ge/graph/passes/merge_to_stream_merge_pass.cc index b785ddfa..34daa681 100644 --- a/src/ge/graph/passes/merge_to_stream_merge_pass.cc +++ 
b/src/ge/graph/passes/merge_to_stream_merge_pass.cc @@ -89,6 +89,16 @@ Status MergeToStreamMergePass::ReplaceMergeNode(const ComputeGraphPtr &graph, co GE_CHK_STATUS_RET(SetNextIteration(stream_merge, next_iteration_name), "Set next iteration failed"); } + if (merge_op_desc->HasAttr(ATTR_NAME_BATCH_LABEL)) { + string batch_label; + (void)AttrUtils::GetStr(merge_op_desc, ATTR_NAME_BATCH_LABEL, batch_label); + if (!batch_label.empty()) { + auto stream_merge_desc = stream_merge->GetOpDesc(); + GE_CHECK_NOTNULL(stream_merge_desc); + (void)AttrUtils::SetStr(stream_merge_desc, ATTR_NAME_BATCH_LABEL, batch_label); + } + } + return AddMemcpyAsyncNodes(graph, stream_merge, false); } diff --git a/src/ge/graph/passes/multi_batch_clone_pass.cc b/src/ge/graph/passes/multi_batch_clone_pass.cc new file mode 100644 index 00000000..3390e783 --- /dev/null +++ b/src/ge/graph/passes/multi_batch_clone_pass.cc @@ -0,0 +1,610 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "graph/passes/multi_batch_clone_pass.h" + +#include "common/ge/ge_util.h" +#include "common/formats/utils/formats_trans_utils.h" +#include "graph/preprocess/multi_batch_options.h" +#include "graph/utils/node_utils.h" +#include "graph/utils/op_desc_utils.h" +#include "register/op_registry.h" + +namespace ge { +namespace { +constexpr uint8_t kDataInIndex = 0; +constexpr uint8_t kDataOutIndex = 0; +constexpr uint8_t kCaseArgIndex = 1; + +const std::string kMultiBatchCaseNode = "ascend_mbatch_shape_case"; +const std::string kMultiBatchIndexNode = "ascend_mbatch_shape_data"; +} // namespace + +Status MultiBatchClonePass::Run(ComputeGraphPtr graph) { + if (graph->GetParentGraph() != nullptr) { + GELOGD("Subgraph %s skip the MultiBatchClonePass", graph->GetName().c_str()); + return SUCCESS; + } + + if (!multibatch::InitDynamicParams(batch_shapes_)) { + GELOGD("There is no multi-batch options, no need clone multi-batch graph"); + return SUCCESS; + } + + GELOGD("Begin to run Multi-batch clone on graph: %s", graph->GetName().c_str()); + GE_CHK_STATUS_RET(multibatch::CheckDynamicParams(batch_shapes_), "Invalid multi-batch param"); + if (CollectIoNodes(graph) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Collect input output nodes failed"); + return INTERNAL_ERROR; + } + + (void)AttrUtils::GetStr(graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id_); + ComputeGraphPtr branch = MakeShared(graph->GetName()); + if (branch == nullptr) { + GELOGE(OUT_OF_MEMORY, "Create multi-batch graph failed"); + return OUT_OF_MEMORY; + } + (void)AttrUtils::SetStr(branch, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id_); + + graph->Swap(*branch); + if (CreateRootGraph(graph) != SUCCESS) { + return FAILED; + } + + if (CreateSubgraphs(graph, branch) != SUCCESS) { + return FAILED; + } + + GE_CHK_STATUS_RET(PruneDirectOutput(graph), "Prune direct output failed"); + GELOGD("MultiBatchClonePass Leave"); + return SUCCESS; +} + +/// +/// @ingroup ge +/// @brief Collect input output node from 
original graph. +/// @param [in] const ComputeGraphPtr &graph: original graph. +/// @return 0: SUCCESS / others: FAILED +/// +Status MultiBatchClonePass::CollectIoNodes(const ComputeGraphPtr &graph) { + for (const auto &node : graph->GetDirectNode()) { + if (node->GetType() == DATA) { + all_data_nodes_.emplace_back(node); + } else if (node->GetType() == CONSTANT) { + all_const_nodes_.emplace_back(node); + } else if (node->GetType() == NETOUTPUT) { + all_output_nodes_.emplace_back(node); + } + + // If the node save as input/output node, delete record. + (void)graph->RemoveInputNode(node); + (void)graph->RemoveOutputNode(node); + } + + if (all_data_nodes_.empty() || all_output_nodes_.size() != 1) { + GELOGE(FAILED, "data nodes: %zu, output nodes: %zu", all_data_nodes_.size(), all_output_nodes_.size()); + return FAILED; + } + + int64_t data_index = 0; + for (size_t i = 0; i < all_data_nodes_.size(); ++i) { + const auto &op_desc = all_data_nodes_[i]->GetOpDesc(); + if (!AttrUtils::GetInt(op_desc, ATTR_NAME_INDEX, data_index)) { + (void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, i); + } + } + + const auto &output = all_output_nodes_[0]; + for (size_t i = 0; i < output->GetAllInDataAnchorsSize(); ++i) { + const auto in_anchor = output->GetInDataAnchor(i); + const auto out_anchor = in_anchor->GetPeerOutAnchor(); + const auto data_node = out_anchor->GetOwnerNode(); + if (data_node->GetType() == DATA) { + direct_output_[i] = data_node->GetName(); + GE_CHK_GRAPH_STATUS_RET( + GraphUtils::RemoveEdge(data_node->GetOutDataAnchor(kDataOutIndex), output->GetInDataAnchor(i)), + "Remove edge failed"); + } + } + + return SUCCESS; +} + +/// +/// @ingroup ge +/// @brief Create nodes for root graph. +/// @param [in] const ComputeGraphPtr &graph: Root/Case graph. 
+/// @return 0: SUCCESS / others: FAILED +/// +Status MultiBatchClonePass::CreateRootGraph(const ComputeGraphPtr &graph) { + uint32_t input_num = all_data_nodes_.size() + all_const_nodes_.size(); + uint32_t output_num = all_output_nodes_[0]->GetAllInDataAnchorsSize(); + + OpDescBuilder op_builder(kMultiBatchCaseNode, CASE); + op_builder.AddInput("branch_index").AddDynamicInput("input", input_num).AddDynamicOutput("output", output_num); + const OpDescPtr op_desc = op_builder.Build(); + if (op_desc == nullptr) { + GELOGE(OUT_OF_MEMORY, "Create multi-batch case desc failed"); + return OUT_OF_MEMORY; + } + + op_desc->RegisterSubgraphIrName("branches", kDynamic); + case_node_ = graph->AddNode(op_desc); + if (case_node_ == nullptr) { + GELOGE(OUT_OF_MEMORY, "Create multi-batch case node failed"); + return OUT_OF_MEMORY; + } + + uint32_t batch_num = static_cast(batch_shapes_.size()); + if (!AttrUtils::SetInt(op_desc, ATTR_NAME_BATCH_NUM, batch_num)) { + GELOGE(FAILED, "Set attr ATTR_NAME_BATCH_NUM failed, Case: %s.", op_desc->GetName().c_str()); + return FAILED; + } + + for (uint32_t i = 0; i < batch_num; i++) { + const std::string &attr_name = ATTR_NAME_PRED_VALUE + "_" + std::to_string(i); + if (!AttrUtils::SetListInt(op_desc, attr_name, batch_shapes_[i])) { + GELOGE(FAILED, "Set attr ATTR_NAME_PRED_VALUE failed, Case: %s.", op_desc->GetName().c_str()); + return FAILED; + } + } + + GE_CHK_STATUS_RET(multibatch::StampDynamicType(op_desc), "Set dynamic type failed"); + + GE_CHK_STATUS_RET(CreateIndexNode(graph), "Create index node failed"); + GE_CHK_STATUS_RET(CreateInputNode(graph), "Create input node failed"); + GE_CHK_STATUS_RET(CreateConstNode(graph), "Create const node failed"); + GE_CHK_STATUS_RET(CreateOutputNode(graph), "Create output node failed"); + + return SUCCESS; +} + +/// +/// @ingroup ge +/// @brief Create index node for root graph. +/// @param [in] const ComputeGraphPtr &graph: Root/Case graph. 
+/// @return 0: SUCCESS / others: FAILED +/// +Status MultiBatchClonePass::CreateIndexNode(const ComputeGraphPtr &graph) { + // Data --> MapIndex --> Case + const OpDescPtr op_desc = MakeShared(kMultiBatchIndexNode, DATA); + if (op_desc == nullptr) { + GELOGE(OUT_OF_MEMORY, "Create multi-batch index node failed"); + return FAILED; + } + + GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_INT32); + if (op_desc->AddInputDesc(data_desc) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Add output desc failed"); + return FAILED; + } + if (op_desc->AddOutputDesc(data_desc) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Add output desc failed"); + return FAILED; + } + + size_t data_index = all_data_nodes_.size(); + (void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, data_index); + (void)AttrUtils::SetBool(op_desc, ATTR_INSERT_BY_MBATCH, true); + + index_node_ = graph->AddNode(op_desc); + if (index_node_ == nullptr) { + GELOGE(OUT_OF_MEMORY, "Create multi-batch case node failed"); + return OUT_OF_MEMORY; + } + + if (GraphUtils::AddEdge(index_node_->GetOutDataAnchor(0), case_node_->GetInDataAnchor(0)) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Failed to add edge between Data:%s to Case:%s", index_node_->GetName().c_str(), + case_node_->GetName().c_str()); + return FAILED; + } + + return SUCCESS; +} + +/// +/// @ingroup ge +/// @brief Create input node for root graph. +/// @param [in] const ComputeGraphPtr &graph: Root/Case graph. 
+/// @return 0: SUCCESS / others: FAILED +/// +Status MultiBatchClonePass::CreateInputNode(const ComputeGraphPtr &graph) { + // Data --> Case + std::vector all_data_nodes; + const size_t arg_index = kCaseArgIndex; + for (size_t i = 0; i < all_data_nodes_.size(); ++i) { + const auto &node = all_data_nodes_[i]; + const OpDescPtr op_desc = AttrUtils::CopyOpDesc(node->GetOpDesc()); + if (op_desc == nullptr) { + GELOGE(OUT_OF_MEMORY, "Create multi-batch Data node failed, name: %s", node->GetName().c_str()); + return FAILED; + } + + if (GraphUtils::CopyTensorAttrs(op_desc, node) != GRAPH_SUCCESS) { + return FAILED; + } + + op_desc->SetName(node->GetName()); + const NodePtr &data = graph->AddNode(op_desc); + GE_CHK_BOOL_EXEC(data != nullptr, return FAILED, "Add node[%s] to graph failed", op_desc->GetName().c_str()); + if (GraphUtils::AddEdge(data->GetOutDataAnchor(0), case_node_->GetInDataAnchor(arg_index + i)) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Failed to add edge between Data:%s to Case:%s", data->GetName().c_str(), + case_node_->GetName().c_str()); + return FAILED; + } + + if (SetMaxShapeToData(data) != SUCCESS) { + return FAILED; + } + all_data_nodes.emplace_back(data); + } + + all_data_nodes_.swap(all_data_nodes); + return SUCCESS; +} + +/// +/// @ingroup ge +/// @brief Create Const node for root graph. +/// @param [in] const ComputeGraphPtr &graph: Root/Case graph. 
+/// @return 0: SUCCESS / others: FAILED +/// +Status MultiBatchClonePass::CreateConstNode(const ComputeGraphPtr &graph) { + // Const --> Case + std::vector all_const_nodes; + const size_t arg_index = kCaseArgIndex + all_data_nodes_.size(); + for (size_t i = 0; i < all_const_nodes_.size(); ++i) { + const auto &node = all_const_nodes_[i]; + const OpDescPtr op_desc = AttrUtils::CopyOpDesc(node->GetOpDesc()); + if (op_desc == nullptr) { + GELOGE(OUT_OF_MEMORY, "Create multi-batch Const node failed, name: %s", node->GetName().c_str()); + return FAILED; + } + + op_desc->SetName(node->GetName()); + if (GraphUtils::CopyTensorAttrs(op_desc, node) != GRAPH_SUCCESS) { + return FAILED; + } + + const NodePtr &data = graph->AddNode(op_desc); + GE_CHK_BOOL_EXEC(data != nullptr, return FAILED, "Add node[%s] to graph failed", op_desc->GetName().c_str()); + if (GraphUtils::AddEdge(data->GetOutDataAnchor(0), case_node_->GetInDataAnchor(arg_index + i)) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Failed to add edge between Const:%s to Case:%s", data->GetName().c_str(), + case_node_->GetName().c_str()); + return FAILED; + } + all_const_nodes.emplace_back(data); + } + + size_t data_index = all_data_nodes_.size(); + for (size_t i = 0; i < all_const_nodes_.size(); ++i, ++data_index) { // Trans subgraph Const to Data. + const OpDescPtr &op_desc = all_const_nodes_[i]->GetOpDesc(); + op_desc->SetType(DATA); + (void)op_desc->DelAttr(ATTR_NAME_WEIGHTS); // Delete weight. + + // Const no InputDesc, Data need InputDesc. + (void)op_desc->AddInputDesc(op_desc->GetOutputDesc(kDataOutIndex)); + (void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, data_index); + } + + all_const_nodes_.swap(all_const_nodes); + return SUCCESS; +} + +/// +/// @ingroup ge +/// @brief Create output node for root graph. +/// @param [in] const ComputeGraphPtr &graph: Root/Case graph. 
+/// @return 0: SUCCESS / others: FAILED +/// +Status MultiBatchClonePass::CreateOutputNode(const ComputeGraphPtr &graph) { + const auto &output = all_output_nodes_[0]; + const OpDescPtr op_desc = AttrUtils::CopyOpDesc(output->GetOpDesc()); + if (op_desc == nullptr) { + GELOGE(OUT_OF_MEMORY, "Create multi-batch output node failed"); + return FAILED; + } + + if (GraphUtils::CopyTensorAttrs(op_desc, output) != GRAPH_SUCCESS) { + return FAILED; + } + + op_desc->SetName(output->GetName()); + const NodePtr &node = graph->AddNode(op_desc); + GE_CHK_BOOL_EXEC(node != nullptr, return FAILED, "Add node[%s] to graph failed", op_desc->GetName().c_str()); + + for (size_t i = 0; i < case_node_->GetAllOutDataAnchorsSize(); ++i) { + const auto it = direct_output_.find(i); + if (it == direct_output_.end()) { + if (GraphUtils::AddEdge(case_node_->GetOutDataAnchor(i), node->GetInDataAnchor(i)) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Failed to add edge between Case:%s to NetOutput:%s", case_node_->GetName().c_str(), + node->GetName().c_str()); + return FAILED; + } + } else { + const auto data_node = graph->FindNode(it->second); + if (data_node == nullptr) { + GELOGE(GE_GRAPH_GRAPH_NODE_NULL, "Data node:%s not found", it->second.c_str()); + return GE_GRAPH_GRAPH_NODE_NULL; + } + if (GraphUtils::AddEdge(data_node->GetOutDataAnchor(kDataOutIndex), node->GetInDataAnchor(i)) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Failed to add edge between Data:%s to NetOutput:%s", data_node->GetName().c_str(), + node->GetName().c_str()); + return FAILED; + } + } + } + + all_output_nodes_.clear(); + all_output_nodes_.emplace_back(node); + return SUCCESS; +} + +/// +/// @ingroup ge +/// @brief Set max shape to Data node in root graph. +/// @param [in] const NodePtr &data: data in Root/Case graph. 
+/// @return 0: SUCCESS / others: FAILED +/// +Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &data) { + auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape(); + const auto &dims = data_shape.GetDims(); + if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) { + return SUCCESS; + } + + size_t max_shape_index = 0; + int64_t max_size = 0; + for (size_t i = 0; i < batch_shapes_.size(); ++i) { + int64_t size = 1; + for (auto dim : batch_shapes_[i]) { + if (INT64_MAX / dim < size) { + GELOGE(PARAM_INVALID, "The shape %s size overflow", formats::ShapeToString(batch_shapes_[i]).c_str()); + return PARAM_INVALID; + } + size *= dim; + } + if (size > max_size) { + max_size = size; + max_shape_index = i; + } + } + + return SetShapeToData(batch_shapes_[max_shape_index], data, data_shape); +} + +/// +/// @ingroup ge +/// @brief Set shape to Data node in branch. +/// @param [in] const NodePtr &data: data in branch. +/// @param [in] const std::vector &shapes: dims of shape. +/// @return 0: SUCCESS / others: FAILED +/// +Status MultiBatchClonePass::UpdataShapeToData(const NodePtr &data, const vector &shapes) { + auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape(); + const auto &dims = data_shape.GetDims(); + if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) { + return SUCCESS; + } + + (void)AttrUtils::SetListInt(data->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims()); + return SetShapeToData(shapes, data, data_shape); +} + +/// +/// @ingroup ge +/// @brief Set max shape to Data node in root graph. +/// @param [in] const std::vector &shapes: dims of shape. +/// @param [in] const NodePtr &data: data in Root/Case graph. +/// @param [in] GeShape &data_shape: dims of data node. 
+/// @return 0: SUCCESS / others: FAILED +/// +Status MultiBatchClonePass::SetShapeToData(const vector &shapes, const NodePtr &data, GeShape &data_shape) { + // must not be error, the calc result has been checked in function InsertSwitchNForData + if (multibatch::CalcShape(shapes, data_shape) != SUCCESS) { + return INTERNAL_ERROR; + } + + if (NodeUtils::UpdateInputShape(*data, kDataInIndex, data_shape) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Failed to update input shape for data %s", data->GetName().c_str()); + return INTERNAL_ERROR; + } + + if (NodeUtils::UpdateOutputShape(*data, kDataOutIndex, data_shape) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Failed to update output shape for data %s", data->GetName().c_str()); + return INTERNAL_ERROR; + } + + GELOGI("Update %s input/output shape to %s", data->GetName().c_str(), formats::ShapeToString(data_shape).c_str()); + return SUCCESS; +} + +/// +/// @ingroup ge +/// @brief Create nodes for root graph. +/// @param [in] const ComputeGraphPtr &graph: Root/Case graph. +/// @param [in] const ComputeGraphPtr &branch: original graph. +/// @return 0: SUCCESS / others: FAILED +/// +Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const ComputeGraphPtr &branch) { + const std::string name = graph->GetName() + "_branche_"; + const auto &op_desc = case_node_->GetOpDesc(); + for (size_t i = 0; i < batch_shapes_.size(); ++i) { + std::vector input_nodes; + std::vector output_nodes; + const std::string prefix = "branche_" + std::to_string(i) + "_"; + ComputeGraphPtr subgraph = (i == 0) ? 
branch : GraphUtils::CloneGraph(branch, prefix, input_nodes, output_nodes); + if (subgraph == nullptr) { + GELOGE(FAILED, "Create multi-batch case node failed"); + return FAILED; + } + + subgraph->SetName(name + std::to_string(i)); + subgraph->SetParentNode(case_node_); + subgraph->SetParentGraph(graph); + (void)AttrUtils::SetStr(subgraph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id_); + all_branch_output_[subgraph] = subgraph->FindFirstNodeMatchType(NETOUTPUT); + + graph->AddSubgraph(subgraph->GetName(), subgraph); + + const std::string key_name = "branches" + std::to_string(i); + op_desc->AddSubgraphName(key_name); + op_desc->SetSubgraphInstanceName(i, subgraph->GetName()); + + for (const auto &data : input_nodes) { + GE_CHK_STATUS_RET(UpdataShapeToData(data, batch_shapes_[i]), "Update %s failed", subgraph->GetName().c_str()); + } + } + + // Origninal graph take as first subgraph, update node name. + for (const auto &n : branch->GetDirectNode()) { + const auto &op_desc = n->GetOpDesc(); + op_desc->SetName("branche_0_" + n->GetName()); + + if (n->GetType() == DATA) { + GE_CHK_STATUS_RET(UpdataShapeToData(n, batch_shapes_[0]), "Update %s failed", branch->GetName().c_str()); + } + } + + return PostProcSubgraph(graph); +} + +/// +/// @ingroup ge +/// @brief Assign parent index for branches. +/// @param [in] const ComputeGraphPtr &graph: Root/Case graph. 
+/// @return 0: SUCCESS / others: FAILED +/// +Status MultiBatchClonePass::PostProcSubgraph(const ComputeGraphPtr &graph) { + auto func_desc = case_node_->GetOpDesc(); + auto post_func = domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(func_desc->GetType()); + if (post_func == nullptr) { + GELOGW("The subgraph post func for node %s type %s is null.", case_node_->GetName().c_str(), + case_node_->GetType().c_str()); + return FAILED; + } + + for (const auto &name : func_desc->GetSubgraphInstanceNames()) { + const auto &subgraph = graph->GetSubgraph(name); + if (subgraph == nullptr) { + GELOGE(FAILED, "Subgraph not found, name: %s", name.c_str()); + return FAILED; + } + + std::string subgraph_name; + GE_CHK_STATUS_RET(func_desc->GetSubgraphNameByInstanceName(subgraph->GetName(), subgraph_name), + "Subgraph: %s get subgraph name failed.", subgraph->GetName().c_str()); + + auto graph = GraphUtils::CreateGraphFromComputeGraph(subgraph); + auto ret = post_func(subgraph_name, graph); + if (ret != SUCCESS) { + GELOGE(FAILED, "Failed to post-process subgraph %s on node %s type %s", graph.GetName().c_str(), + case_node_->GetName().c_str(), case_node_->GetType().c_str()); + return FAILED; + } + } + + return SUCCESS; +} + +/// +/// @ingroup ge +/// @brief Remove subgraph suspend output anchor. +/// @param [in] ComputeGraphPtr &graph: Parent compute graph. 
+/// @return 0: SUCCESS / others: FAILED +/// +Status MultiBatchClonePass::PruneDirectOutput(const ComputeGraphPtr &graph) { + const auto &func_desc = case_node_->GetOpDesc(); + uint32_t unused_num = 0; + uint32_t output_num = func_desc->GetOutputsSize(); + for (size_t i = 0; i < output_num; ++i) { + bool is_unused_tensor = true; + for (const auto &item : all_branch_output_) { + const auto &netoutput = item.second; + GE_CHECK_NOTNULL(netoutput); + const auto in_anchor = netoutput->GetInDataAnchor(i); + if (in_anchor->GetPeerOutAnchor() != nullptr) { + is_unused_tensor = false; + break; + } + } + + if (is_unused_tensor) { + unused_num++; + continue; + } + + GE_CHK_STATUS_RET(UpdateOutputTensor(i, unused_num), "Graph:%s Update output failed", graph->GetName().c_str()); + } + + if (unused_num == 0) { + return SUCCESS; + } + + GE_CHK_STATUS_RET(NodeUtils::RemoveOutputAnchor(case_node_, output_num - unused_num), "Remove output failed"); + for (const auto &item : all_branch_output_) { + GE_CHK_STATUS_RET(NodeUtils::RemoveInputAnchor(item.second, output_num - unused_num), "Remove input failed"); + } + + return SUCCESS; +} + +/// +/// @ingroup ge +/// @brief Update subgraph suspend output tensor. +/// @param [in] parent_index: parent index for check. +/// @param [in] unused_num: total unused tensor. 
+/// @return 0: SUCCESS / others: FAILED +/// +Status MultiBatchClonePass::UpdateOutputTensor(uint32_t parent_index, uint32_t unused_num) { + if (unused_num == 0) { + return SUCCESS; + } + + uint32_t update_index = parent_index - unused_num; + for (const auto &item : all_branch_output_) { + const auto &node = item.second; + const auto &new_anchor = node->GetInDataAnchor(update_index); + const auto &old_anchor = node->GetInDataAnchor(parent_index); + const auto &out_anchor = old_anchor->GetPeerOutAnchor(); + const auto &out_node = out_anchor->GetOwnerNode(); + + const auto &op_desc = node->GetOpDesc(); + (void)op_desc->UpdateInputDesc(update_index, op_desc->GetInputDesc(parent_index)); + + GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(out_anchor, new_anchor), "Add edge failed"); + GELOGI("Add edge success, func node: %s, node: %s, parent index: %u, update index: %u", + case_node_->GetName().c_str(), out_node->GetName().c_str(), parent_index, update_index); + + GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(out_anchor, old_anchor), "Remove edge failed"); + GELOGI("Remove edge success, func node: %s, node: %s", case_node_->GetName().c_str(), out_node->GetName().c_str()); + } + + const auto &new_anchor = case_node_->GetOutDataAnchor(update_index); + const auto &old_anchor = case_node_->GetOutDataAnchor(parent_index); + for (const auto in_anchor : old_anchor->GetPeerInDataAnchors()) { + const auto &in_node = in_anchor->GetOwnerNode(); + GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(old_anchor, in_anchor), "Remove edge failed"); + GELOGI("Remove edge success, func node: %s, node: %s", case_node_->GetName().c_str(), in_node->GetName().c_str()); + + GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(new_anchor, in_anchor), "Add edge failed"); + GELOGI("Add edge success, func node: %s, node: %s, parent index: %u, update index: %u", + case_node_->GetName().c_str(), in_node->GetName().c_str(), parent_index, update_index); + } + + return SUCCESS; +} +} // namespace ge diff --git 
a/src/ge/graph/passes/multi_batch_clone_pass.h b/src/ge/graph/passes/multi_batch_clone_pass.h new file mode 100644 index 00000000..1da08e78 --- /dev/null +++ b/src/ge/graph/passes/multi_batch_clone_pass.h @@ -0,0 +1,155 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_MULTI_BATCH_CLONE_PASS_H_ +#define GE_GRAPH_PASSES_MULTI_BATCH_CLONE_PASS_H_ + +#include +#include +#include + +#include "inc/graph_pass.h" + +namespace ge { +class MultiBatchClonePass : public GraphPass { + public: + Status Run(ComputeGraphPtr graph); + + private: + /// + /// @ingroup ge + /// @brief Collect input output node from original graph. + /// @param [in] const ComputeGraphPtr &graph: original graph. + /// @return 0: SUCCESS / others: FAILED + /// + Status CollectIoNodes(const ComputeGraphPtr &graph); + + /// + /// @ingroup ge + /// @brief Create nodes for root graph. + /// @param [in] const ComputeGraphPtr &graph: Root/Case graph. + /// @return 0: SUCCESS / others: FAILED + /// + Status CreateRootGraph(const ComputeGraphPtr &graph); + + /// + /// @ingroup ge + /// @brief Create index node for root graph. + /// @param [in] const ComputeGraphPtr &graph: Root/Case graph. + /// @return 0: SUCCESS / others: FAILED + /// + Status CreateIndexNode(const ComputeGraphPtr &graph); + + /// + /// @ingroup ge + /// @brief Create input node for root graph. 
+ /// @param [in] const ComputeGraphPtr &graph: Root/Case graph. + /// @return 0: SUCCESS / others: FAILED + /// + Status CreateInputNode(const ComputeGraphPtr &graph); + + /// + /// @ingroup ge + /// @brief Create Const node for root graph. + /// @param [in] const ComputeGraphPtr &graph: Root/Case graph. + /// @return 0: SUCCESS / others: FAILED + /// + Status CreateConstNode(const ComputeGraphPtr &graph); + + /// + /// @ingroup ge + /// @brief Create output node for root graph. + /// @param [in] const ComputeGraphPtr &graph: Root/Case graph. + /// @return 0: SUCCESS / others: FAILED + /// + Status CreateOutputNode(const ComputeGraphPtr &graph); + + /// + /// @ingroup ge + /// @brief Set max shape to Data node in root graph. + /// @param [in] const NodePtr &data: data in Root/Case graph. + /// @return 0: SUCCESS / others: FAILED + /// + Status SetMaxShapeToData(const NodePtr &data); + + /// + /// @ingroup ge + /// @brief Set shape to Data node in branch. + /// @param [in] const NodePtr &data: data in branch. + /// @param [in] const std::vector &shapes: dims of shape. + /// @return 0: SUCCESS / others: FAILED + /// + Status UpdataShapeToData(const NodePtr &data, const std::vector &shapes); + + /// + /// @ingroup ge + /// @brief Set max shape to Data node in root graph. + /// @param [in] const std::vector &shapes: dims of shape. + /// @param [in] const NodePtr &data: data in Root/Case graph. + /// @param [in] GeShape &data_shape: dims of data node. + /// @return 0: SUCCESS / others: FAILED + /// + Status SetShapeToData(const std::vector &shapes, const NodePtr &data, GeShape &data_shape); + + /// + /// @ingroup ge + /// @brief Create nodes for root graph. + /// @param [in] const ComputeGraphPtr &graph: Root/Case graph. + /// @param [in] const ComputeGraphPtr &branch: original graph. 
+ /// @return 0: SUCCESS / others: FAILED + /// + Status CreateSubgraphs(const ComputeGraphPtr &graph, const ComputeGraphPtr &branch); + + /// + /// @ingroup ge + /// @brief Assign parent index for branches. + /// @param [in] const ComputeGraphPtr &graph: Root/Case graph. + /// @return 0: SUCCESS / others: FAILED + /// + Status PostProcSubgraph(const ComputeGraphPtr &graph); + + /// + /// @ingroup ge + /// @brief Remove subgraph suspend output anchor. + /// @param [in] ComputeGraphPtr &graph: Parent compute graph. + /// @return 0: SUCCESS / others: FAILED + /// + Status PruneDirectOutput(const ComputeGraphPtr &graph); + + /// + /// @ingroup ge + /// @brief Update subgraph suspend output tensor. + /// @param [in] parent_index: parent index for check. + /// @param [in] unused_num: total unused tensor. + /// @return 0: SUCCESS / others: FAILED + /// + Status UpdateOutputTensor(uint32_t parent_index, uint32_t unused_num); + + std::string session_graph_id_; + std::vector> batch_shapes_; + + std::vector all_data_nodes_; + std::vector all_const_nodes_; + std::vector all_output_nodes_; + + std::map direct_output_; + std::map all_branch_output_; + + NodePtr case_node_; + NodePtr index_node_; +}; +} // namespace ge +#endif  // GE_GRAPH_PASSES_MULTI_BATCH_CLONE_PASS_H_ diff --git a/src/ge/graph/passes/multi_batch_pass.cc b/src/ge/graph/passes/multi_batch_pass.cc index 7d484a25..32152a6f 100644 --- a/src/ge/graph/passes/multi_batch_pass.cc +++ b/src/ge/graph/passes/multi_batch_pass.cc @@ -18,15 +18,8 @@ #include #include -#include - #include "common/ge/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "framework/common/debug/log.h" -#include "framework/common/ge_inner_error_codes.h" -#include "framework/common/types.h" #include "graph/common/omg_util.h" -#include "graph/debug/ge_attr_define.h" #include "graph/utils/type_utils.h" using std::string; @@ -55,7 +48,10 @@ Status MultiBatchPass::Run(ComputeGraphPtr graph) { GELOGE(FAILED, "Get dynamic type failed."); 
return FAILED; } - + if (GetUserDesignateShape() != SUCCESS) { + GELOGE(FAILED, "Get user designate shape failed."); + return FAILED; + } std::vector> batch_shape; vector> combined_batch; if (!CheckSwitchN(batch_shape, combined_batch)) { @@ -63,7 +59,14 @@ Status MultiBatchPass::Run(ComputeGraphPtr graph) { return FAILED; } - FindSwitchOutNodes(batch_shape.size()); + if (attach_label_only_) { + return AttachLabelOnly(batch_shape.size()); + } + + if (FindSwitchOutNodes(batch_shape.size()) != SUCCESS) { + GELOGE(FAILED, "Find SwitchN out nodes failed."); + return FAILED; + } if (ReplaceSwitchN(graph, pred_value, batch_shape, combined_batch) != SUCCESS) { GELOGE(FAILED, "Replace SwitchN nodes failed."); @@ -81,6 +84,17 @@ Status MultiBatchPass::Run(ComputeGraphPtr graph) { return SUCCESS; } +/// +/// @brief Clear Status +/// @return +/// +Status MultiBatchPass::ClearStatus() { + switch_n_nodes_.clear(); + bypass_nodes_.clear(); + batch_head_nodes_.clear(); + return SUCCESS; +} + /// /// @brief Replace & Combine SwitchN nodes /// @param [in] graph @@ -159,6 +173,40 @@ Status MultiBatchPass::GetDynamicType() { return SUCCESS; } +/// +/// @brief Get user designate shape order. 
eg{"data","label","mask"} +/// @return Status +/// +Status MultiBatchPass::GetUserDesignateShape() { + data_name_order_.clear(); + bool first_check = true; + for (const auto &switchn : switch_n_nodes_) { + auto switchn_desc = switchn->GetOpDesc(); + GE_CHECK_NOTNULL(switchn_desc); + vector cur_switchn_data_name_order; + if (!AttrUtils::GetListStr(switchn_desc, ATTR_USER_DESIGNEATE_SHAPE_ORDER, cur_switchn_data_name_order)) { + GELOGE(FAILED, "Get attr ATTR_USER_DESIGNEATE_SHAPE_ORDER of node: %s failed.", switchn->GetName().c_str()); + return FAILED; + } + if (first_check) { + data_name_order_ = cur_switchn_data_name_order; + first_check = false; + } else { + if (data_name_order_ != cur_switchn_data_name_order) { + GELOGE(FAILED, "The ATTR_USER_DESIGNEATE_SHAPE_ORDER of switchN must be same: %s failed.", + switchn->GetName().c_str()); + return FAILED; + } + } + } + if (data_name_order_.empty()) { + GELOGE(FAILED, "user shape order can not be empty"); + return FAILED; + } + + return SUCCESS; +} + /// /// @brief Check SwitchN nodes /// @param [out] batch_shape @@ -259,22 +307,41 @@ bool MultiBatchPass::GetBatchInfo(uint32_t batch_num, vector> &b /// @param [in] batch_num /// @return void /// -void MultiBatchPass::FindSwitchOutNodes(uint32_t batch_num) { +Status MultiBatchPass::FindSwitchOutNodes(uint32_t batch_num) { std::vector output_nodes; for (uint32_t i = 0; i < batch_num; i++) { output_nodes.clear(); for (const NodePtr &node : switch_n_nodes_) { // idx is promised to be valid OutDataAnchorPtr out_data_anchor = node->GetOutDataAnchor(i); - GE_CHECK_NOTNULL_JUST_RETURN(out_data_anchor); + GE_CHECK_NOTNULL(out_data_anchor); for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { - output_nodes.emplace_back(peer_in_anchor->GetOwnerNode()); + auto out_node = peer_in_anchor->GetOwnerNode(); + if (out_node->GetType() != IDENTITY || !out_node->GetOutDataNodes().empty()) { + output_nodes.emplace_back(out_node); + continue; + } + 
bypass_nodes_.emplace_back(out_node); + if (GraphUtils::RemoveEdge(out_data_anchor, peer_in_anchor) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Remove SwitchN out_data_edge failed, %s->%s.", node->GetName().c_str(), + out_node->GetName().c_str()); + return FAILED; + } + for (auto &identity_out_node : out_node->GetOutControlNodes()) { + output_nodes.emplace_back(identity_out_node); + if (GraphUtils::RemoveEdge(out_node->GetOutControlAnchor(), identity_out_node->GetInControlAnchor()) != + GRAPH_SUCCESS) { + GELOGE(FAILED, "Remove SwitchN out_data_edge failed, %s->%s.", node->GetName().c_str(), + out_node->GetName().c_str()); + return FAILED; + } + } } } batch_head_nodes_.emplace_back(output_nodes); } - return; + return SUCCESS; } /// @@ -355,7 +422,6 @@ bool MultiBatchPass::CheckDims(const std::vector> &output_s } } } - return true; } @@ -404,6 +470,10 @@ NodePtr MultiBatchPass::CreateSwitchCaseNode(const ComputeGraphPtr &graph, const GELOGE(FAILED, "Set attr ATTR_DYNAMIC_TYPE failed, StreamSwitchN:%s.", name.c_str()); return nullptr; } + if (!AttrUtils::SetListStr(op_desc, ATTR_USER_DESIGNEATE_SHAPE_ORDER, data_name_order_)) { + GELOGE(FAILED, "Set attr ATTR_USER_DESIGNEATE_SHAPE_ORDER failed, StreamSwitchN:%s.", name.c_str()); + return nullptr; + } for (uint32_t i = 0; i < batch_num; i++) { const std::string &attr_name = ATTR_NAME_PRED_VALUE + "_" + std::to_string(i); if (!AttrUtils::SetListInt(op_desc, attr_name, batch_shape[i])) { @@ -468,6 +538,7 @@ Status MultiBatchPass::BypassSwitchN(const NodePtr &switch_n_node, const NodePtr } } } + GE_CHK_STATUS_RET(MoveCtrlEdges(switch_n_node, switch_case), "Move ctrl edges failed."); bypass_nodes_.emplace_back(switch_n_node); GELOGI("Bypass SwitchN node %s success.", switch_n_node->GetName().c_str()); @@ -495,7 +566,7 @@ Status MultiBatchPass::AttachLabel(const NodePtr &switch_case_node) { stream_label_list.emplace_back(stream_label); } - return SetActiveLabelList(switch_case_node, stream_label_list); + return switch_case_node 
== nullptr ? SUCCESS : SetActiveLabelList(switch_case_node, stream_label_list); } /// @@ -595,4 +666,53 @@ Status MultiBatchPass::AttachStreamLabel(uint32_t batch_idx, const std::string & return SUCCESS; } + +/// +/// @brief move edges from old_node to new_node +/// @param [in] old_node +/// @param [in] new_node +/// @return Status +/// +Status MultiBatchPass::MoveCtrlEdges(const NodePtr &old_node, const NodePtr &new_node) { + if (old_node == new_node) { + return SUCCESS; + } + for (const NodePtr &in_ctrl_node : old_node->GetInControlNodes()) { + GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), old_node->GetInControlAnchor()), + "Merge remove in ctrl edge failed."); + GE_CHK_STATUS(GraphUtils::AddEdge(in_ctrl_node->GetOutControlAnchor(), new_node->GetInControlAnchor()), + "StreamMerge add in ctrl edge failed."); + } + + for (const NodePtr &out_ctrl_node : old_node->GetOutControlNodes()) { + GE_CHK_STATUS(GraphUtils::RemoveEdge(old_node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()), + "Merge remove out ctrl edge failed."); + GE_CHK_STATUS(GraphUtils::AddEdge(new_node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()), + "StreamMerge add out ctrl edge failed."); + } + return SUCCESS; +} + +/// +/// @brief attach stream_label & batch_label without change structure of graph +/// @param [in] batch_num +/// @return void +/// +Status MultiBatchPass::AttachLabelOnly(uint32_t batch_num) { + std::vector output_nodes; + for (uint32_t i = 0; i < batch_num; i++) { + output_nodes.clear(); + for (const NodePtr &node : switch_n_nodes_) { + // idx is promised to be valid + OutDataAnchorPtr out_data_anchor = node->GetOutDataAnchor(i); + GE_CHECK_NOTNULL(out_data_anchor); + for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { + output_nodes.emplace_back(peer_in_anchor->GetOwnerNode()); + } + } + batch_head_nodes_.emplace_back(output_nodes); + } + + return AttachLabel(nullptr); +} } // namespace ge 
diff --git a/src/ge/graph/passes/multi_batch_pass.h b/src/ge/graph/passes/multi_batch_pass.h index 8f14ec0a..1806229f 100644 --- a/src/ge/graph/passes/multi_batch_pass.h +++ b/src/ge/graph/passes/multi_batch_pass.h @@ -25,7 +25,10 @@ namespace ge { class MultiBatchPass : public GraphPass { public: - Status Run(ComputeGraphPtr graph); + explicit MultiBatchPass(bool attach_label_only = false) : attach_label_only_(attach_label_only) {} + ~MultiBatchPass() override = default; + Status Run(ComputeGraphPtr graph) override; + Status ClearStatus() override; private: Status FindPredValue(const ComputeGraphPtr &graph, OutDataAnchorPtr &pred_value); @@ -33,7 +36,7 @@ class MultiBatchPass : public GraphPass { bool CheckSwitchN(std::vector> &batch_shape, std::vector> &combined_batch); bool GetBatchInfo(uint32_t batch_num, std::vector> &batch_shape, std::vector> &combined_batch); - void FindSwitchOutNodes(uint32_t batch_num); + Status FindSwitchOutNodes(uint32_t batch_num); Status ReplaceSwitchN(const ComputeGraphPtr &graph, const OutDataAnchorPtr &pred_value, const std::vector> &batch_shape, const std::vector> &combined_batch); @@ -46,11 +49,16 @@ class MultiBatchPass : public GraphPass { Status AttachLabel(const NodePtr &switch_case_node); Status AttachBatchLabel(uint32_t batch_idx); Status AttachStreamLabel(uint32_t batch_idx, const std::string &stream_label); + Status MoveCtrlEdges(const NodePtr &old_node, const NodePtr &new_node); + Status AttachLabelOnly(uint32_t batch_num); + Status GetUserDesignateShape(); std::vector switch_n_nodes_; std::vector bypass_nodes_; std::vector> batch_head_nodes_; + std::vector data_name_order_; int32_t dynamic_type_ = 0; + bool attach_label_only_; }; } // namespace ge #endif // GE_GRAPH_PASSES_MULTI_BATCH_PASS_H_ diff --git a/src/ge/graph/passes/net_output_pass.cc b/src/ge/graph/passes/net_output_pass.cc index dd17f99c..f9c3835f 100644 --- a/src/ge/graph/passes/net_output_pass.cc +++ b/src/ge/graph/passes/net_output_pass.cc @@ -37,6 +37,9 @@ 
static std::map output_type_str_to_datatype = { {"UINT16", ge::DT_UINT16}, {"UINT8", ge::DT_UINT8}, {"INT32", ge::DT_INT32}, {"INT64", ge::DT_INT64}, {"UINT32", ge::DT_UINT32}, {"UINT64", ge::DT_UINT64}, {"DOUBLE", ge::DT_DOUBLE}}; +// the size of user defined output datatype or format string after split by ":". +const size_t kUserDefinedElementCount = 2; + Status NetOutputPass::GetRetvalOutputInfo(const ge::NodePtr &node, std::map &retval_node_index_map) { GE_CHECK_NOTNULL(node); @@ -552,27 +555,43 @@ void NetOutputPass::AddInOutForNetOutputOp(const ComputeGraphPtr &graph, OpDescP net_output_desc->SetIsInputConst(is_input_const); } -bool NeedUpdateOutputByOutputTypeParm(std::string &output_type, NodePtr &src_node, uint32_t src_index, +bool NeedUpdateOutputByOutputTypeParm(std::string &output_type, OpDescPtr &op_desc, uint32_t &src_index, ge::DataType &dt) { if (output_type_str_to_datatype.find(output_type) != output_type_str_to_datatype.end()) { dt = output_type_str_to_datatype[output_type]; return true; } - auto op_desc = src_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - vector output_data_type_vec; - vector index_vec; - if ((ge::AttrUtils::GetListDataType(op_desc, "_output_dt_list", output_data_type_vec)) && - (ge::AttrUtils::GetListInt(op_desc, "_output_dt_index", index_vec))) { - if (output_data_type_vec.size() != index_vec.size()) { - GELOGW("output_dt_list size is not match output_dt_index size"); - return false; - } - for (uint32_t i = 0; i < index_vec.size(); ++i) { - if (index_vec[i] == src_index) { - dt = output_data_type_vec[i]; - return true; + vector output_dt_str; + if (ge::AttrUtils::GetListStr(op_desc, "_user_defined_output_data_type", output_dt_str)) { + for (const auto &dt_str : output_dt_str) { + vector dt_str_split = StringUtils::Split(dt_str, ':'); + if (dt_str_split.size() == kUserDefinedElementCount) { + if (dt_str_split[0] == to_string(src_index)) { + dt = TypeUtils::SerialStringToDataType(dt_str_split[1]); + return true; + } + } else { 
+ GELOGW("The size of [%s] is not 2 after split.", dt_str.c_str()); + continue; + } + } + } + return false; +} + +bool NeedUpdateOutputFp16Nc1hwc0(OpDescPtr &op_desc, uint32_t &src_index) { + vector output_dt_str; + if (ge::AttrUtils::GetListStr(op_desc, "_user_defined_output_fp16_5hd", output_dt_str)) { + for (const auto &dt_str : output_dt_str) { + vector dt_str_split = StringUtils::Split(dt_str, ':'); + if (dt_str_split.size() == kUserDefinedElementCount) { + if (dt_str_split[0] == to_string(src_index)) { + return true; + } + } else { + GELOGW("The size of [%s] is not 2 after split.", dt_str.c_str()); + continue; } } } @@ -601,9 +620,11 @@ Status NetOutputPass::SetUserDefDTypeAndFormatFromAtcParams(const NodePtr &outpu auto src_index = static_cast(peer_out->GetIdx()); auto src_node = peer_out->GetOwnerNode(); GE_CHECK_NOTNULL(src_node); + OpDescPtr src_op_desc = src_node->GetOpDesc(); + GE_CHECK_NOTNULL(src_op_desc); // Update datatype - if (NeedUpdateOutputByOutputTypeParm(output_type, src_node, src_index, output_data_type)) { + if (NeedUpdateOutputByOutputTypeParm(output_type, src_op_desc, src_index, output_data_type)) { GELOGD("Add user-define datatype:%s to netoutput node.", TypeUtils::DataTypeToSerialString(output_data_type).c_str()); userdef_dtypes.push_back( @@ -611,10 +632,7 @@ Status NetOutputPass::SetUserDefDTypeAndFormatFromAtcParams(const NodePtr &outpu continue; } // Output_node is not set,check if is_output_adjust_hw_layout is set - OpDescPtr src_op_desc = src_node->GetOpDesc(); - GE_CHECK_NOTNULL(src_op_desc); - bool set_fp16_nc1hwc0 = false; - (void)AttrUtils::GetBool(src_op_desc, "output_set_fp16_nc1hwc0", set_fp16_nc1hwc0); + bool set_fp16_nc1hwc0 = NeedUpdateOutputFp16Nc1hwc0(src_op_desc, src_index); if (set_fp16_nc1hwc0) { // Set DT_FLOAT16 & FORMAT_NC1HWC0 userdef_dtypes.push_back(std::to_string(index).append(":").append(TypeUtils::DataTypeToSerialString(DT_FLOAT16))); diff --git a/src/ge/graph/passes/next_iteration_pass.cc 
b/src/ge/graph/passes/next_iteration_pass.cc index 12cde11e..73b3b77e 100644 --- a/src/ge/graph/passes/next_iteration_pass.cc +++ b/src/ge/graph/passes/next_iteration_pass.cc @@ -35,6 +35,10 @@ Status NextIterationPass::Run(ComputeGraphPtr graph) { return INTERNAL_ERROR; } } + if (GroupWithNoBatch(graph) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Group enter_nodes failed without batch_label attr."); + return INTERNAL_ERROR; + } if (FindWhileGroups() != SUCCESS) { GELOGE(INTERNAL_ERROR, "Find while groups failed."); @@ -69,17 +73,75 @@ Status NextIterationPass::GroupEnterNode(const NodePtr &enter_node) { return FAILED; } - auto iter = loop_group_map_.find(frame_name); - if (iter == loop_group_map_.end()) { + std::string batch_label; + (void)ge::AttrUtils::GetStr(enter_desc, ATTR_NAME_BATCH_LABEL, batch_label); + if (batch_label.empty()) { + auto frame_iter = frame_enter_map_.find(frame_name); + if (frame_iter == frame_enter_map_.end()) { + std::vector enter_nodes; + enter_nodes.emplace_back(enter_node); + frame_enter_map_[frame_name] = enter_nodes; + } else { + frame_iter->second.emplace_back(enter_node); + } + return SUCCESS; + } + + auto group_iter = loop_group_map_.find(frame_name); + if (group_iter == loop_group_map_.end()) { LoopCondGroupPtr loop_group = MakeShared(); if (loop_group == nullptr) { GELOGE(FAILED, "MakeShared for LoopCondGroup failed."); return FAILED; } loop_group->enter_nodes.emplace_back(enter_node); - loop_group_map_[frame_name] = loop_group; + loop_group_map_[frame_name][batch_label] = loop_group; } else { - iter->second->enter_nodes.emplace_back(enter_node); + auto batch_iter = group_iter->second.find(batch_label); + if (batch_iter == group_iter->second.end()) { + LoopCondGroupPtr loop_group = MakeShared(); + if (loop_group == nullptr) { + GELOGE(FAILED, "MakeShared for LoopCondGroup failed."); + return FAILED; + } + loop_group->enter_nodes.emplace_back(enter_node); + group_iter->second[batch_label] = loop_group; + } else { + 
batch_iter->second->enter_nodes.emplace_back(enter_node); + } + } + + return SUCCESS; +} + +/// +/// @brief Group Enter nodes without batch_label attr +/// @param [in] compute_graph +/// @return Status +/// +Status NextIterationPass::GroupWithNoBatch(const ComputeGraphPtr &graph) { + if (frame_enter_map_.empty()) { + GELOGI("All enter nodes in graph %s has batch_label attr.", graph->GetName().c_str()); + return SUCCESS; + } + for (const auto &item : frame_enter_map_) { + const std::string &frame_name = item.first; + auto iter = loop_group_map_.find(frame_name); + if (iter == loop_group_map_.end()) { + LoopCondGroupPtr loop_group = MakeShared(); + if (loop_group == nullptr) { + GELOGE(FAILED, "MakeShared for LoopCondGroup failed."); + return FAILED; + } + loop_group->enter_nodes = item.second; + loop_group_map_[frame_name][""] = loop_group; + } else { + for (auto &batch_item : iter->second) { + for (const auto &enter_node : item.second) { + batch_item.second->enter_nodes.emplace_back(enter_node); + } + } + } } return SUCCESS; @@ -92,39 +154,50 @@ Status NextIterationPass::GroupEnterNode(const NodePtr &enter_node) { Status NextIterationPass::FindWhileGroups() { for (const auto &loop_group_iter : loop_group_map_) { const std::string &frame_name = loop_group_iter.first; - for (const auto &enter_node : loop_group_iter.second->enter_nodes) { - for (const auto &out_node : enter_node->GetOutAllNodes()) { - const std::string &type = out_node->GetType(); - if ((type != MERGE) && (type != REFMERGE)) { - continue; - } - - NodePtr next_node = nullptr; - if (FindTargetNode(out_node, NEXTITERATION, true, next_node) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Get NextIteration node failed, frame_name: %s.", frame_name.c_str()); - return INTERNAL_ERROR; - } - loop_group_iter.second->merge_next_pairs.emplace_back(std::make_pair(out_node, next_node)); - - NodePtr switch_node = nullptr; - if (FindTargetNode(out_node, SWITCH, false, switch_node) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Get 
Switch node failed, frame_name: %s.", frame_name.c_str()); - return INTERNAL_ERROR; - } - if (switch_node == nullptr) { - continue; - } - - NodePtr loop_cond = nullptr; - if (FindTargetNode(switch_node, LOOPCOND, true, loop_cond) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Get LoopCond node failed, frame_name: %s.", frame_name.c_str()); - return INTERNAL_ERROR; - } - if (loop_group_iter.second->loop_cond == nullptr) { - loop_group_iter.second->loop_cond = loop_cond; - } else if (loop_group_iter.second->loop_cond != loop_cond) { - GELOGE(FAILED, "Multi LoopCond nodes exist, frame_name: %s.", frame_name.c_str()); - return FAILED; + for (const auto &batch_iter : loop_group_iter.second) { + const std::string &batch_label = batch_iter.first; + for (const auto &enter_node : batch_iter.second->enter_nodes) { + for (const auto &out_node : enter_node->GetOutAllNodes()) { + GELOGI("Find while_group for enter_node %s, frame_name:%s, batch_label:%s.", enter_node->GetName().c_str(), + frame_name.c_str(), batch_label.c_str()); + if ((out_node->GetType() != MERGE) && (out_node->GetType() != REFMERGE)) { + continue; + } + std::string tmp_label; + GE_CHECK_NOTNULL(out_node->GetOpDesc()); + (void)AttrUtils::GetStr(out_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, tmp_label); + bool need_skip = !(batch_label.empty() || tmp_label.empty() || (batch_label == tmp_label)); + if (need_skip) { + continue; + } + + NodePtr next_node = nullptr; + if (FindTargetNode(out_node, NEXTITERATION, true, batch_label, next_node) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Get NextIteration node failed."); + return INTERNAL_ERROR; + } + batch_iter.second->merge_next_pairs.emplace_back(std::make_pair(out_node, next_node)); + + NodePtr switch_node = nullptr; + if (FindTargetNode(out_node, SWITCH, false, batch_label, switch_node) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Get Switch node failed."); + return INTERNAL_ERROR; + } + if (switch_node == nullptr) { + continue; + } + + NodePtr loop_cond = nullptr; + if 
(FindTargetNode(switch_node, LOOPCOND, true, batch_label, loop_cond) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Get LoopCond node failed."); + return INTERNAL_ERROR; + } + if (batch_iter.second->loop_cond == nullptr) { + batch_iter.second->loop_cond = loop_cond; + } else if (batch_iter.second->loop_cond != loop_cond) { + GELOGE(FAILED, "Multi LoopCond nodes exist."); + return FAILED; + } } } } @@ -145,18 +218,19 @@ bool NextIterationPass::VerifyWhileGroup() { GELOGE(INTERNAL_ERROR, "Verify while group failed, frame_name is empty."); return false; } - - if (loop_group_iter.second->loop_cond == nullptr) { - GELOGE(INTERNAL_ERROR, "Verify while group failed, LoopCond is null, frame_name: %s.", frame_name.c_str()); - return false; - } - - for (const auto &pair_iter : loop_group_iter.second->merge_next_pairs) { - if ((pair_iter.first == nullptr) || (pair_iter.second == nullptr)) { - GELOGE(INTERNAL_ERROR, "Verify while group failed, merge_node/next_node is null, frame_name: %s.", - frame_name.c_str()); + for (const auto &batch_iter : loop_group_iter.second) { + if (batch_iter.second->loop_cond == nullptr) { + GELOGE(INTERNAL_ERROR, "Verify while group failed, LoopCond is null, frame_name: %s.", frame_name.c_str()); return false; } + + for (const auto &pair_iter : batch_iter.second->merge_next_pairs) { + if ((pair_iter.first == nullptr) || (pair_iter.second == nullptr)) { + GELOGE(INTERNAL_ERROR, "Verify while group failed, merge_node/next_node is null, frame_name: %s.", + frame_name.c_str()); + return false; + } + } } } @@ -170,52 +244,56 @@ bool NextIterationPass::VerifyWhileGroup() { /// Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { for (const auto &loop_cond_iter : loop_group_map_) { - const std::string &cond_name = loop_cond_iter.second->loop_cond->GetName(); - GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str()); - - // Create Active node, Enter->Active->Merge, NextIteration->Active->Merge - NodePtr enter_active = 
CreateActiveNode(graph, cond_name + "_Enter_" + STREAMACTIVE); - NodePtr next_active = CreateActiveNode(graph, cond_name + "_Next_" + STREAMACTIVE); - if ((enter_active == nullptr) || (next_active == nullptr)) { - GELOGE(INTERNAL_ERROR, "Create active node failed, cond_name: %s.", cond_name.c_str()); - return INTERNAL_ERROR; - } - - for (const auto &enter_node : loop_cond_iter.second->enter_nodes) { - // Enter --> Active - if (GraphUtils::AddEdge(enter_node->GetOutControlAnchor(), enter_active->GetInControlAnchor()) != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Add control edge failed."); + for (const auto &batch_iter : loop_cond_iter.second) { + const std::string &cond_name = batch_iter.second->loop_cond->GetName(); + GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str()); + + // Create Active node, Enter->Active->Merge, NextIteration->Active->Merge + NodePtr enter_active = CreateActiveNode(graph, cond_name + "_Enter_" + STREAMACTIVE); + NodePtr next_active = CreateActiveNode(graph, cond_name + "_Next_" + STREAMACTIVE); + if ((enter_active == nullptr) || (next_active == nullptr)) { + GELOGE(INTERNAL_ERROR, "Create active node failed, cond_name: %s.", cond_name.c_str()); return INTERNAL_ERROR; } - } - for (const auto &pair : loop_cond_iter.second->merge_next_pairs) { - NodePtr merge_node = pair.first; - NodePtr next_node = pair.second; - // Active --> Merge - if (GraphUtils::AddEdge(enter_active->GetOutControlAnchor(), merge_node->GetInControlAnchor()) != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Add control edge failed."); - return INTERNAL_ERROR; + for (const auto &enter_node : batch_iter.second->enter_nodes) { + // Enter --> Active + if (GraphUtils::AddEdge(enter_node->GetOutControlAnchor(), enter_active->GetInControlAnchor()) != + GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Add control edge failed."); + return INTERNAL_ERROR; + } } - // NextIteration --> Active - if (GraphUtils::AddEdge(next_node->GetOutControlAnchor(), 
next_active->GetInControlAnchor()) != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Add control edge failed."); - return INTERNAL_ERROR; + for (const auto &pair : batch_iter.second->merge_next_pairs) { + NodePtr merge_node = pair.first; + NodePtr next_node = pair.second; + // Active --> Merge + if (GraphUtils::AddEdge(enter_active->GetOutControlAnchor(), merge_node->GetInControlAnchor()) != + GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Add control edge failed."); + return INTERNAL_ERROR; + } + + // NextIteration --> Active + if (GraphUtils::AddEdge(next_node->GetOutControlAnchor(), next_active->GetInControlAnchor()) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Add control edge failed."); + return INTERNAL_ERROR; + } + + // break link between NextIteration and Merge + if (BreakNextIteration(next_node, merge_node) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Break NextIteration failed"); + return INTERNAL_ERROR; + } } - // break link between NextIteration and Merge - if (BreakNextIteration(next_node, merge_node) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Break NextIteration failed"); + if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) || + (SetActiveLabelList(next_active, {cond_name}) != SUCCESS)) { + GELOGE(INTERNAL_ERROR, "Set attr ACTIVE_LABEL_LIST failed."); return INTERNAL_ERROR; } } - - if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) || - (SetActiveLabelList(next_active, {cond_name}) != SUCCESS)) { - GELOGE(INTERNAL_ERROR, "Set attr ACTIVE_LABEL_LIST failed."); - return INTERNAL_ERROR; - } } return SUCCESS; @@ -282,11 +360,12 @@ Status NextIterationPass::BreakNextIteration(const NodePtr &next_node, NodePtr & /// @param [in] node /// @param [in] target_type /// @param [in] is_input +/// @param [in] batch_label /// @param [out] target_node /// @return Status /// Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input, - NodePtr &target_node) { + const std::string &batch_label, NodePtr 
&target_node) { if (node == nullptr) { GELOGE(PARAM_INVALID, "node is null."); return PARAM_INVALID; @@ -303,6 +382,12 @@ Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string } for (const auto &tmp_node : nodes) { + std::string tmp_label; + (void)AttrUtils::GetStr(tmp_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, tmp_label); + bool need_skip = !(batch_label.empty() || tmp_label.empty() || (batch_label == tmp_label)); + if (need_skip) { + continue; + } const std::string type = tmp_node->GetType(); if ((target_type == LOOPCOND) && (type == target_type)) { target_node = tmp_node; @@ -325,6 +410,7 @@ Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string /// @return SUCCESS /// Status NextIterationPass::ClearStatus() { + frame_enter_map_.clear(); loop_group_map_.clear(); return SUCCESS; } diff --git a/src/ge/graph/passes/next_iteration_pass.h b/src/ge/graph/passes/next_iteration_pass.h index 4cdf4b51..6f28a618 100644 --- a/src/ge/graph/passes/next_iteration_pass.h +++ b/src/ge/graph/passes/next_iteration_pass.h @@ -46,6 +46,13 @@ class NextIterationPass : public GraphPass { /// Status GroupEnterNode(const NodePtr &enter_node); + /// + /// @brief Group Enter nodes without batch_label attr + /// @param [in] compute_graph + /// @return Status + /// + Status GroupWithNoBatch(const ComputeGraphPtr &graph); + /// /// @brief Find while groups /// @return Status @@ -86,13 +93,17 @@ class NextIterationPass : public GraphPass { /// @param [in] node /// @param [in] target_type /// @param [in] is_input + /// @param [in] batch_label /// @param [out] target_node /// @return Status /// - Status FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input, NodePtr &target_node); + Status FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input, + const std::string &batch_label, NodePtr &target_node); - // map - std::unordered_map loop_group_map_; + // map> + std::unordered_map> 
frame_enter_map_; + // map> + std::unordered_map> loop_group_map_; }; } // namespace ge #endif // GE_GRAPH_PASSES_NEXT_ITERATION_PASS_H_ diff --git a/src/ge/graph/passes/replace_with_empty_const_pass.cc b/src/ge/graph/passes/replace_with_empty_const_pass.cc index cb35238b..212b1979 100644 --- a/src/ge/graph/passes/replace_with_empty_const_pass.cc +++ b/src/ge/graph/passes/replace_with_empty_const_pass.cc @@ -89,10 +89,13 @@ Status ReplaceWithEmptyConstPass::ReplaceWithEmptyConst(NodePtr &node_to_replace } // Repalce data anchors - if (GraphUtils::ReplaceNodeDataAnchors(const_node, node_to_replace, {}, shape_2_out_idx.second) != GRAPH_SUCCESS) { - GELOGE(FAILED, "[%s] ReplaceNodeAnchors failed.", node_to_replace->GetName().c_str()); - return FAILED; + for (const auto &anchor_idx : shape_2_out_idx.second) { + if (GraphUtils::ReplaceNodeDataAnchors(const_node, node_to_replace, {}, {anchor_idx}) != GRAPH_SUCCESS) { + GELOGE(FAILED, "[%s] ReplaceNodeAnchors failed.", node_to_replace->GetName().c_str()); + return FAILED; + } } + // Copy in control edge if (GraphUtils::CopyInCtrlEdges(node_to_replace, const_node) != GRAPH_SUCCESS) { GELOGE(FAILED, "CopyInCtrlEdges from %s to %s failed.", node_to_replace->GetName().c_str(), diff --git a/src/ge/graph/passes/subexpression_migration_pass.cc b/src/ge/graph/passes/subexpression_migration_pass.cc new file mode 100644 index 00000000..cb09a743 --- /dev/null +++ b/src/ge/graph/passes/subexpression_migration_pass.cc @@ -0,0 +1,559 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "subexpression_migration_pass.h" + +#include "graph/utils/node_utils.h" +#include "ge_local_engine/engine/host_cpu_engine.h" +#include "graph/passes/folding_pass.h" + +namespace ge { +constexpr uint32_t kDataOutIndex = 0; +constexpr uint32_t kCaseInputBase = 1; +constexpr uint32_t kInvalidParent = 0x7fffffffU; +const std::set kTransOpTypes = {"Cast", "TransData", "Reshape", "BnHost"}; + +bool IsSameTensor(ConstGeTensorDescPtr src_tensor, ConstGeTensorDescPtr dst_tensor) { + if ((src_tensor == nullptr) && (dst_tensor == nullptr)) { + return true; + } + if ((src_tensor == nullptr) || (dst_tensor == nullptr)) { + return false; + } + + if ((src_tensor->GetDataType() != dst_tensor->GetDataType()) || + (src_tensor->GetFormat() != dst_tensor->GetFormat())) { + return false; + } + + const auto src_dims = src_tensor->GetShape().GetDims(); + const auto dst_dims = dst_tensor->GetShape().GetDims(); + if (src_dims != dst_dims) { + return false; + } + + const auto src_orig_dims = src_tensor->GetOriginShape().GetDims(); + const auto dst_orig_dims = dst_tensor->GetOriginShape().GetDims(); + if (src_orig_dims != dst_orig_dims) { + return false; + } + + return true; +} + +bool IsSameOpDesc(const OpDescPtr &src_desc, const OpDescPtr &dst_desc) { + if ((src_desc == nullptr) && (dst_desc == nullptr)) { + return true; + } + + if ((src_desc == nullptr) || (dst_desc == nullptr)) { + return false; + } + + if (src_desc->GetType() != dst_desc->GetType()) { + return false; + } + + if ((src_desc->GetInputsSize() != dst_desc->GetInputsSize()) || + 
(src_desc->GetOutputsSize() != dst_desc->GetOutputsSize())) { + return false; + } + + for (uint32_t i = 0; i < src_desc->GetInputsSize(); ++i) { + if (!IsSameTensor(src_desc->GetInputDescPtr(i), dst_desc->GetInputDescPtr(i))) { + return false; + } + } + + for (uint32_t i = 0; i < src_desc->GetOutputsSize(); ++i) { + if (!IsSameTensor(src_desc->GetOutputDescPtr(i), dst_desc->GetOutputDescPtr(i))) { + return false; + } + } + + return true; +} + +Status SubexpressionMigrationPass::Run(ComputeGraphPtr graph) { + GE_CHECK_NOTNULL(graph); + if (graph->GetParentGraph() != nullptr) { + GELOGD("Subgraph %s skip the SubexpressionMigrationPass", graph->GetName().c_str()); + return SUCCESS; + } + GELOGD("Begin to run Subexpression Migration on graph: %s", graph->GetName().c_str()); + + for (const auto &node : graph->GetDirectNode()) { + if (node->GetType() != CASE) { + continue; + } + + const auto &func_desc = node->GetOpDesc(); + if (!func_desc->HasAttr(ATTR_NAME_BATCH_NUM)) { + GELOGD("Not multi-batch, Case: %s", node->GetName().c_str()); + continue; + } + + do { + migration_append_ = false; + map> graph_nodes; + if (ClassifyDataNodes(graph, func_desc, graph_nodes) != SUCCESS) { + return FAILED; + } + + if (graph_nodes.empty()) { + GELOGW("Graph: %s nodes is empty", graph->GetName().c_str()); + break; + } + + // {subgraph0, {{1, Data}, {2, Data}, {3, Data}, {4, Data}, ..., {n, Data}}} + // {subgraph1, {{1, Data}, {2, Data}, {3, Data}, {4, Data}, ..., {n, Data}}} + // {subgraph2, {{1, Data}, {2, Data}, {3, Data}, {4, Data}, ..., {n, Data}}} + const auto base_nodes = graph_nodes.begin()->second; // Need copy. + for (const auto &node_item : base_nodes) { + if (GraphNodeMigration(graph, node, graph_nodes, node_item.second, node_item.first) != SUCCESS) { + return FAILED; + } + } + } while (migration_append_); + } + + return SUCCESS; +} + +/// +/// @ingroup ge +/// @brief Get all Data nodes for all subgraph. +/// @param [in] graph: Root compute graph. 
+/// @param [in] func_desc: functional OpDesc of Case. +/// @param [out] graph_nodes: Data groups of subgraph. +/// @return 0: SUCCESS / others: FAILED +/// +Status SubexpressionMigrationPass::ClassifyDataNodes(const ComputeGraphPtr &graph, const OpDescPtr &func_desc, + map> &graph_nodes) { + for (const auto &name : func_desc->GetSubgraphInstanceNames()) { + const auto &subgraph = graph->GetSubgraph(name); + if (subgraph == nullptr) { + GELOGE(GE_GRAPH_EMPTY_SUBGRAPH, "Subgraph not found, name: %s", name.c_str()); + return GE_GRAPH_EMPTY_SUBGRAPH; + } + + auto &data_nodes = graph_nodes[subgraph]; + for (auto &data : subgraph->GetDirectNode()) { + if (data->GetType() != DATA) { + continue; + } + + uint32_t parent_index = 0; + if (!AttrUtils::GetInt(data->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + GELOGE(FAILED, "Parent index not found, name: %s", data->GetName().c_str()); + return FAILED; + } + + data_nodes[parent_index] = data; + GELOGD("Subgraph %s has %zu Data nodes", subgraph->GetName().c_str(), data_nodes.size()); + } + } + + for (const auto &data_nodes : graph_nodes) { + if (data_nodes.second.size() != graph_nodes.begin()->second.size()) { + GELOGE(FAILED, "Subgraph %s has invalid Data nodes[%zu != %zu]", data_nodes.first->GetName().c_str(), + data_nodes.second.size(), graph_nodes.begin()->second.size()); + return FAILED; + } + } + + return SUCCESS; +} + +/// +/// @ingroup ge +/// @brief Get all Data nodes for all subgraph. +/// @param [in] graph: Root compute graph. +/// @param [in] func_desc: functional OpDesc of Case. +/// @param [out] graph_nodes: Data groups of subgraph. 
+/// @return true: SUCCESS / false: FAILED +/// +bool SubexpressionMigrationPass::GetAssociatedNodes(const NodePtr &node, map &inputs, + map &outputs) { + for (uint32_t i = 0; i < node->GetAllOutDataAnchorsSize(); ++i) { + outputs[i] = kInvalidParent; + } + + uint32_t out_index = 0; + for (uint32_t i = 0; i < node->GetAllInDataAnchorsSize(); ++i) { + const auto &in_anchor = node->GetInDataAnchor(i); + const auto &out_anchor = in_anchor->GetPeerOutAnchor(); + if (out_anchor == nullptr) { + inputs[i] = kInvalidParent; + continue; + } + + // Has none Data input node, Can not move to parent. + const auto &owner_node = out_anchor->GetOwnerNode(); + if (owner_node->GetType() != DATA) { + return false; + } + + uint32_t parent_index = 0; + if (!AttrUtils::GetInt(owner_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + return false; + } + + // Input Data feed other Node, need add new Data. + inputs[i] = parent_index; + if ((out_index < outputs.size()) && (owner_node->GetOutDataNodesSize() == 1)) { + outputs[out_index] = parent_index; + ++out_index; + } + } + + return true; +} + +/// +/// @ingroup ge +/// @brief Get all Data nodes for all subgraph. +/// @param [in] graph_nodes: Data groups of subgraph. +/// @param [in] data_base: Data Node for migration. +/// @param [in] data_idx: Data groups of subgraph. +/// @param [in] data_idx: Data groups of subgraph. 
+/// @return true: Same / false: not same +/// +bool SubexpressionMigrationPass::IsParallelNodeSame(const map> &graph_nodes, + const NodePtr &base_node, uint32_t node_idx, uint32_t anchor_idx) { + auto it = graph_nodes.begin(); + for (++it; it != graph_nodes.end(); ++it) { + const auto &data_nodes = it->second; + auto data_it = data_nodes.find(node_idx); + if (data_it == data_nodes.end()) { + GELOGE(FAILED, "Data: %s not fount, index: %u", base_node->GetName().c_str(), node_idx); + return false; + } + + const auto &work_data = data_it->second; + const auto &out_anchor = work_data->GetOutDataAnchor(kDataOutIndex); + const auto &in_ahchors = out_anchor->GetPeerInDataAnchors(); + const auto &in_anchor = in_ahchors.at(anchor_idx); + if (in_anchor == nullptr) { + GELOGE(FAILED, "Data anchor size: %u, anchor size: %zu", anchor_idx, in_ahchors.size()); + return false; + } + + const auto &work_node = in_anchor->GetOwnerNode(); + if (work_node == nullptr) { + GELOGE(FAILED, "Data: %s not found, index: %u", base_node->GetName().c_str(), node_idx); + return false; + } + + if (!IsSameOpDesc(base_node->GetOpDesc(), work_node->GetOpDesc())) { + GELOGI("OpDesc diff: %s %s", base_node->GetName().c_str(), work_node->GetName().c_str()); + return false; + } + } + + return true; +} + +/// +/// @ingroup ge +/// @brief Migration subgraph Node to Root +/// @param [in] graph: Root compute graph. +/// @param [in] func_node: functional Node of Case. +/// @param [in] graph_nodes: Data groups of subgraph. +/// @param [in] data_base: Data Node for migration. +/// @param [in] data_idx: Data groups of subgraph. 
+/// @return 0: SUCCESS / others: FAILED +/// +Status SubexpressionMigrationPass::GraphNodeMigration(const ComputeGraphPtr &graph, const NodePtr &func_node, + map> &graph_nodes, + const NodePtr &base_data, uint32_t base_idx) { + bool can_extrapolation = false; + do { + can_extrapolation = false; + const auto out_anchor = base_data->GetOutDataAnchor(kDataOutIndex); + const auto in_anchors = out_anchor->GetPeerInDataAnchors(); + for (size_t i = 0; i < in_anchors.size(); ++i) { + const auto &in_anchor = in_anchors.at(i); + const auto &base_node = in_anchor->GetOwnerNode(); + if (kTransOpTypes.count(base_node->GetType()) == 0) { + continue; + } + + // Get associated Data, if Data feed other nodes, need append new Data. + map inputs; + map outputs; + if (!GetAssociatedNodes(base_node, inputs, outputs)) { + continue; + } + + if (!IsParallelNodeSame(graph_nodes, base_node, base_idx, i)) { + continue; + } + + GELOGI("Move to parent: %s", base_node->GetName().c_str()); + if (AppendParallelNode(graph_nodes, func_node, outputs) != SUCCESS) { + return FAILED; + } + + if (MoveNodeToParent(graph, func_node, graph_nodes, i, inputs, outputs) != SUCCESS) { + return FAILED; + } + can_extrapolation = true; + break; + } + } while (can_extrapolation); + + return SUCCESS; +} + +/// +/// @ingroup ge +/// @brief Append Input Tensor for functional node. +/// @param [in] graph_nodes: Data groups of subgraph. +/// @param [in] func_node: functional Node of Case. +/// @param [in] outputs: Parent index of Node output. +/// @return 0: SUCCESS / others: FAILED +/// +Status SubexpressionMigrationPass::AppendParallelNode(map> &graph_nodes, + const NodePtr &func_node, map &outputs) { + // If outputs index invalid, add Data and Input Tensor. + for (auto &item : outputs) { + if (item.second != kInvalidParent) { + continue; + } + + // Add Data to subgraph. 
+ for (auto &groups : graph_nodes) { + const auto &subgraph = groups.first; + auto &data_nodes = groups.second; + + uint32_t data_index = data_nodes.size(); + item.second = data_index + kCaseInputBase; // Update to valid parent index. + std::string data_name = subgraph->GetName() + "_data_" + std::to_string(item.second); + + OpDescBuilder op_builder(data_name, DATA); + const OpDescPtr op_desc = op_builder.AddInput("x").AddOutput("y").Build(); + if (op_desc == nullptr) { + GELOGE(OUT_OF_MEMORY, "Create multi-batch case desc failed"); + return OUT_OF_MEMORY; + } + + if (!AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, data_index)) { + GELOGE(FAILED, "Parent index not found, name: %s", op_desc->GetName().c_str()); + return FAILED; + } + + if (!AttrUtils::SetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, item.second)) { + GELOGE(FAILED, "Parent index not found, name: %s", op_desc->GetName().c_str()); + return FAILED; + } + + data_nodes[item.second] = subgraph->AddNode(op_desc); + } + + // Add InputTensor to functional Node. + NodeUtils::AppendInputAnchor(func_node, item.second + 1); + migration_append_ = true; + } + + return SUCCESS; +} + +/// +/// @ingroup ge +/// @brief Delete Node from all subgraph. +/// @param [in] graph_nodes: Data groups of subgraph. +/// @param [in] detach: Node will move to parent. +/// @param [in] outputs: Parent index of Node output. +/// @return 0: SUCCESS / others: FAILED +/// +Status SubexpressionMigrationPass::DetachParallelNode(const map &graph_datas, const NodePtr &detach, + const map &outputs) { + // Break Data and Move node. 
+ for (const auto &in_anchor : detach->GetAllInDataAnchors()) { + const auto &out_anchor = in_anchor->GetPeerOutAnchor(); + if (out_anchor == nullptr) { + continue; + } + GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(out_anchor, in_anchor), "Remove edge failed"); + + const auto &owner_node = out_anchor->GetOwnerNode(); + GELOGI("Remove Edge: %s %s", owner_node->GetName().c_str(), detach->GetName().c_str()); + } + + // Break Move and follow, Link Data and follow. + for (uint32_t i = 0; i < detach->GetAllOutDataAnchorsSize(); ++i) { + auto it_idx = outputs.find(i); + if (it_idx == outputs.end()) { + GELOGE(FAILED, "Node: %s parent index %u not found", detach->GetName().c_str(), i); + return FAILED; + } + + auto it_data = graph_datas.find(it_idx->second); + if (it_data == graph_datas.end()) { + GELOGE(FAILED, "Node: %s parent index %u not found", detach->GetName().c_str(), i); + return FAILED; + } + + const auto &data_node = it_data->second; + const auto &out_anchor = detach->GetOutDataAnchor(i); + + const auto &out_desc = detach->GetOpDesc()->GetOutputDesc(i); + const auto &data_desc = data_node->GetOpDesc(); + (void)data_desc->UpdateInputDesc(kDataOutIndex, out_desc); // Set Data Input to new connect Node. + (void)data_desc->UpdateOutputDesc(kDataOutIndex, out_desc); // Set Data Output to new connect Node. 
+ + for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { + if (in_anchor == nullptr) { + continue; + } + GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(out_anchor, in_anchor), "Remove edge failed"); + const auto &owner_node = in_anchor->GetOwnerNode(); + GELOGI("Remove Edge: %s %s", detach->GetName().c_str(), owner_node->GetName().c_str()); + + const auto &data_out_anchor = data_node->GetOutDataAnchor(kDataOutIndex); + GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(data_out_anchor, in_anchor), "Add edge failed"); + GELOGI("Add Edge: %s %s", data_node->GetName().c_str(), owner_node->GetName().c_str()); + } + } + + return SUCCESS; +} + +/// +/// @ingroup ge +/// @brief Move Node to Parent Graph. +/// @param [in] graph: Parent compute graph. +/// @param [in] func_node: functional Node of Case. +/// @param [in] attach: Node will move to parent. +/// @param [in] inputs: Parent index of Node input. +/// @param [in] outputs: Parent index of Node output. +/// @return 0: SUCCESS / others: FAILED +/// +Status SubexpressionMigrationPass::AttachParallelNode(const ComputeGraphPtr &graph, const NodePtr &func_node, + const NodePtr &attach, const map &inputs, + const map &outputs) { + GE_CHECK_NOTNULL(attach); + for (uint32_t i = 0; i < attach->GetAllInDataAnchorsSize(); ++i) { + auto it_idx = inputs.find(i); + if (it_idx == inputs.end()) { + GELOGE(FAILED, "Node: %s parent index %u not found", attach->GetName().c_str(), i); + return FAILED; + } + if (it_idx->second == kInvalidParent) { // Not connnect, Skip. 
+ continue; + } + + const auto &in_anchor = func_node->GetInDataAnchor(it_idx->second); + const auto &out_anchor = in_anchor->GetPeerOutAnchor(); + GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(out_anchor, attach->GetInDataAnchor(i)), "Add edge failed"); + const auto &owner_node = out_anchor->GetOwnerNode(); + GELOGI("Add Edge: %s %s", owner_node->GetName().c_str(), attach->GetName().c_str()); + } + + for (uint32_t i = 0; i < attach->GetAllOutDataAnchorsSize(); ++i) { + auto it_idx = outputs.find(i); + if (it_idx == outputs.end()) { + return FAILED; + } + if (it_idx->second == kInvalidParent) { // Not connnect, Skip. + continue; + } + + const auto &out_desc = attach->GetOpDesc()->GetOutputDesc(i); + const auto &func_desc = func_node->GetOpDesc(); + (void)func_desc->UpdateInputDesc(it_idx->second, out_desc); // Set Data Input to new connect Node. + + const auto &in_anchor = func_node->GetInDataAnchor(it_idx->second); + const auto &out_anchor = in_anchor->GetPeerOutAnchor(); + if (out_anchor != nullptr) { + GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(out_anchor, in_anchor), "Remove edge failed"); + const auto &owner_node = out_anchor->GetOwnerNode(); + GELOGI("Remove Edge: %s %s", owner_node->GetName().c_str(), func_node->GetName().c_str()); + } + GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(attach->GetOutDataAnchor(i), in_anchor), "Add edge failed"); + GELOGI("Add Edge: %s %s", attach->GetName().c_str(), func_node->GetName().c_str()); + } + + (void)graph->AddNode(attach); + (void)attach->SetOwnerComputeGraph(graph); + GELOGI("Add Node: %s %s", graph->GetName().c_str(), attach->GetName().c_str()); + + return SUCCESS; +} + +/// +/// @ingroup ge +/// @brief Move node to Parent graph. +/// @param [in] graph: Root compute graph. +/// @param [in] func_node: functional Node of Case. +/// @param [in] graph_nodes: Data groups of subgraph. +/// @param [in] anchor_idx: anchor index of move Node. +/// @param [in] inputs: Parent index of Node input. 
+/// @param [in] outputs: Parent index of Node output. +/// @return 0: SUCCESS / others: FAILED +/// +Status SubexpressionMigrationPass::MoveNodeToParent(const ComputeGraphPtr &graph, const NodePtr &func_node, + const map> &graph_nodes, + uint32_t anchor_idx, const map &inputs, + const map &outputs) { + if (inputs.empty()) { + GELOGE(FAILED, "Graph: %s, inputs is empty", graph->GetName().c_str()); + return FAILED; + } + + NodePtr move_node; + uint32_t base_index = inputs.begin()->second; + for (auto &groups : graph_nodes) { + const auto &subgraph = groups.first; + const auto &subnodes = groups.second; + auto it = subnodes.find(base_index); + if (it == subnodes.end()) { + GELOGE(FAILED, "Graph: %s, Data: %u node not found", subgraph->GetName().c_str(), base_index); + return FAILED; + } + + const auto &base_data = it->second; + const auto &out_anchor = base_data->GetOutDataAnchor(kDataOutIndex); + const auto &in_anchors = out_anchor->GetPeerInDataAnchors(); + const auto &in_anchor = in_anchors.at(anchor_idx); + if (in_anchor == nullptr) { + GELOGE(FAILED, "Data anchor index: %u, anchor size: %zu", anchor_idx, in_anchors.size()); + return FAILED; + } + + move_node = in_anchor->GetOwnerNode(); + if (move_node == nullptr) { + GELOGE(FAILED, "Data: %s not found, index: %u", base_data->GetName().c_str(), base_index); + return FAILED; + } + + if (DetachParallelNode(subnodes, move_node, outputs) != SUCCESS) { + GELOGE(FAILED, "Data: %s not found, index: %u", base_data->GetName().c_str(), base_index); + return FAILED; + } + + GE_CHK_GRAPH_STATUS_RET(subgraph->RemoveNode(move_node), "Remove node failed"); + GELOGI("Remove Node: %s %s", subgraph->GetName().c_str(), move_node->GetName().c_str()); + } + + if (AttachParallelNode(graph, func_node, move_node, inputs, outputs) != SUCCESS) { + return FAILED; + } + + return SUCCESS; +} +} // namespace ge diff --git a/src/ge/graph/passes/subexpression_migration_pass.h b/src/ge/graph/passes/subexpression_migration_pass.h new file mode 
100644 index 00000000..ac750725 --- /dev/null +++ b/src/ge/graph/passes/subexpression_migration_pass.h @@ -0,0 +1,137 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_COMMON_SUBEXPRESSION_MIGRATION_H_ +#define GE_COMMON_SUBEXPRESSION_MIGRATION_H_ + +#include "graph/types.h" +#include "inc/graph_pass.h" + +#include +#include +#include +#include + +using std::map; +using std::set; + +namespace ge { +class SubexpressionMigrationPass : public GraphPass { + public: + Status Run(ComputeGraphPtr graph) override; + + private: + /// + /// @ingroup ge + /// @brief Get all Data nodes for all subgraph. + /// @param [in] graph: Root compute graph. + /// @param [in] func_desc: functional OpDesc of Case. + /// @param [out] graph_nodes: Data groups of subgraph. + /// @return 0: SUCCESS / others: FAILED + /// + Status ClassifyDataNodes(const ComputeGraphPtr &graph, const OpDescPtr &func_desc, + map> &graph_nodes); + + /// + /// @ingroup ge + /// @brief Get all Data nodes for all subgraph. + /// @param [in] graph: Root compute graph. + /// @param [in] func_desc: functional OpDesc of Case. + /// @param [out] graph_nodes: Data groups of subgraph. + /// @return true: SUCCESS / false: FAILED + /// + bool GetAssociatedNodes(const NodePtr &node, map &inputs, map &outputs); + + /// + /// @ingroup ge + /// @brief Get all Data nodes for all subgraph. + /// @param [in] graph_nodes: Data groups of subgraph. 
+ /// @param [in] data_base: Data Node for migration. + /// @param [in] data_idx: Data groups of subgraph. + /// @param [in] data_idx: Data groups of subgraph. + /// @return true: Same / false: not same + /// + bool IsParallelNodeSame(const map> &graph_nodes, const NodePtr &base_node, + uint32_t base_idx, uint32_t anchor_idx); + + /// + /// @ingroup ge + /// @brief Migration subgraph Node to Root + /// @param [in] graph: Root compute graph. + /// @param [in] func_node: functional Node of Case. + /// @param [in] graph_nodes: Data groups of subgraph. + /// @param [in] data_base: Data Node for migration. + /// @param [in] data_idx: Data groups of subgraph. + /// @return 0: SUCCESS / others: FAILED + /// + Status GraphNodeMigration(const ComputeGraphPtr &graph, const NodePtr &func_node, + map> &graph_nodes, const NodePtr &data_base, + uint32_t data_idx); + + /// + /// @ingroup ge + /// @brief Move node to Parent graph. + /// @param [in] graph: Root compute graph. + /// @param [in] func_node: functional Node of Case. + /// @param [in] graph_nodes: Data groups of subgraph. + /// @param [in] anchor_idx: anchor index of move Node. + /// @param [in] inputs: Parent index of Node input. + /// @param [in] outputs: Parent index of Node output. + /// @return 0: SUCCESS / others: FAILED + /// + Status MoveNodeToParent(const ComputeGraphPtr &graph, const NodePtr &func_node, + const map> &graph_nodes, uint32_t anchor_idx, + const map &inputs, const map &outputs); + + /// + /// @ingroup ge + /// @brief Append Input Tensor for functional node. + /// @param [in] graph_nodes: Data groups of subgraph. + /// @param [in] func_node: functional Node of Case. + /// @param [in] outputs: Parent index of Node output. + /// @return 0: SUCCESS / others: FAILED + /// + Status AppendParallelNode(map> &graph_nodes, const NodePtr &func_node, + map &outputs); + + /// + /// @ingroup ge + /// @brief Delete Node from all subgraph. + /// @param [in] graph_nodes: Data groups of subgraph. 
+ /// @param [in] detach: Node will move to parent. + /// @param [in] outputs: Parent index of Node output. + /// @return 0: SUCCESS / others: FAILED + /// + Status DetachParallelNode(const map &graph_datas, const NodePtr &detach, + const map &outputs); + + /// + /// @ingroup ge + /// @brief Move Node to Parent Graph. + /// @param [in] graph: Parent compute graph. + /// @param [in] func_node: functional Node of Case. + /// @param [in] attach: Node will move to parent. + /// @param [in] inputs: Parent index of Node input. + /// @param [in] outputs: Parent index of Node output. + /// @return 0: SUCCESS / others: FAILED + /// + Status AttachParallelNode(const ComputeGraphPtr &graph, const NodePtr &func_node, const NodePtr &attach, + const map &inputs, const map &outputs); + + bool migration_append_{false}; +}; +} // namespace ge +#endif // GE_COMMON_SUBEXPRESSION_MIGRATION_H_ \ No newline at end of file diff --git a/src/ge/graph/passes/subgraph_pass.cc b/src/ge/graph/passes/subgraph_pass.cc index 80ce995a..fbf444fb 100644 --- a/src/ge/graph/passes/subgraph_pass.cc +++ b/src/ge/graph/passes/subgraph_pass.cc @@ -98,21 +98,14 @@ Status SubgraphPass::SubgraphInputNode(const ComputeGraphPtr &graph, const NodeP return FAILED; } - NodePtr in_node = NodeUtils::GetParentInput(node); - GE_CHECK_NOTNULL(in_node); // Subgraph Data Node, check for constant input. std::string const_type; - if (!NodeUtils::GetConstOpType(in_node, const_type)) { + if (!NodeUtils::GetConstOpType(node, const_type)) { return SUCCESS; } const NodePtr &parent_node = graph->GetParentNode(); - if (kWhileOpTypes.count(parent_node->GetType()) == 0) { - if (!AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_PARENT_CONST_TYPE, const_type)) { - GELOGE(FAILED, "Set attr PARENT_NODE_INDEX failed, node:%s.", node->GetName().c_str()); - return FAILED; - } - } else { + if (kWhileOpTypes.count(parent_node->GetType()) != 0) { // Constant input to While need memcpy. 
const ComputeGraphPtr &parent_graph = parent_node->GetOwnerComputeGraph(); GE_CHECK_NOTNULL(parent_graph); @@ -211,6 +204,10 @@ Status SubgraphPass::WhileBodySubgraph(const ComputeGraphPtr &graph, const NodeP GELOGE(FAILED, "while_body of %s is NULL.", node->GetName().c_str()); return FAILED; } + if (GraphUtils::IsUnknownShapeGraph(while_body)) { + GELOGI("Unknown shape while_body graph %s no need to insert memcpy.", while_body->GetName().c_str()); + return SUCCESS; + } std::vector data_nodes; std::set bypass_index; @@ -258,7 +255,7 @@ Status SubgraphPass::InsertInputMemcpy(const ComputeGraphPtr &graph, const std:: } std::string in_name = graph->GetName() + "_input_Memcpy"; - OpDescBuilder in_builder(in_name, MEMCPYASYNC); + OpDescBuilder in_builder(in_name, IDENTITY); for (size_t i = 0; i < data_nodes.size(); i++) { // Data node has and only has one output in_builder.AddInput("x" + std::to_string(i), data_nodes[i]->GetOpDesc()->GetOutputDesc(0)) @@ -300,7 +297,7 @@ Status SubgraphPass::InsertOutputMemcpy(const ComputeGraphPtr &graph, const Node } std::string out_name = graph->GetName() + "_output_Memcpy"; - OpDescBuilder out_builder(out_name, MEMCPYASYNC); + OpDescBuilder out_builder(out_name, IDENTITY); for (size_t i = 0; i < output_node->GetAllInDataAnchorsSize(); i++) { if (bypass_index.count(i) == 0) { out_builder.AddInput("x" + std::to_string(i), output_node->GetOpDesc()->GetInputDesc(i)) @@ -438,12 +435,13 @@ Status SubgraphPass::InsertMemcpyNode(const ComputeGraphPtr &graph, const OutDat const std::vector &in_anchors, const std::string &name) { GE_CHECK_NOTNULL(out_anchor); NodePtr in_node = out_anchor->GetOwnerNode(); - OpDescBuilder op_desc_builder(name, MEMCPYASYNC); + OpDescBuilder op_desc_builder(name, IDENTITY); OpDescPtr op_desc = op_desc_builder.AddInput("x", in_node->GetOpDesc()->GetOutputDesc(0)) .AddOutput("y", in_node->GetOpDesc()->GetOutputDesc(0)) .Build(); + (void)AttrUtils::SetBool(op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); if 
(GraphUtils::InsertNodeAfter(out_anchor, in_anchors, graph->AddNode(op_desc)) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Insert MemcpyAsync node %s after %s failed.", name.c_str(), in_node->GetName().c_str()); + GELOGE(FAILED, "Insert IDENTITY node %s after %s failed.", name.c_str(), in_node->GetName().c_str()); return FAILED; } diff --git a/src/ge/graph/passes/switch_fusion_pass.cc b/src/ge/graph/passes/switch_fusion_pass.cc deleted file mode 100644 index f475d857..00000000 --- a/src/ge/graph/passes/switch_fusion_pass.cc +++ /dev/null @@ -1,249 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "switch_fusion_pass.h" -#include -#include -#include "common/ge/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "framework/common/util.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/node_utils.h" -namespace ge { -namespace { -const int kSwitchDataInputIdx = 0; -const int kSwitchCondInputIdx = 1; -const int kSwitchFalseOutIdx = 0; -const int kSwitchTrueOutIdx = 1; -int GetSwitchOutDataIdx(const string fusion_group_id) { return std::stoi(fusion_group_id.substr(0)); } -} // namespace - -Status SwitchFusionPass::Run(NodePtr &node) { - GE_CHECK_NOTNULL(node); - GE_CHECK_NOTNULL(node->GetOpDesc()); - if (node->GetOpDesc()->GetType() != SWITCH && node->GetOpDesc()->GetType() != REFSWITCH) { - return SUCCESS; - } - GELOGD("Switch fusion pass in.Current switch node name is %s", node->GetName().c_str()); - // 1. find cond input - auto switch_in_cond_anchor = node->GetInDataAnchor(kSwitchCondInputIdx); - if (switch_in_cond_anchor->GetPeerOutAnchor() == nullptr) { - GELOGI("Switch %s in condition peer out anchor is null.", node->GetName().c_str()); - return FAILED; - } - auto switch_cond_in_node = switch_in_cond_anchor->GetPeerOutAnchor()->GetOwnerNode(); - GELOGD("Switch %s cond in data node is %s.", node->GetName().c_str(), switch_cond_in_node->GetName().c_str()); - if (switch_cond_in_node->GetOutDataNodesSize() == 1) { - GELOGI("This condition only has one switch, no need fusion."); - return SUCCESS; - } - // 2. find other switch with same condition - for (const auto out_data_node : switch_cond_in_node->GetOutDataNodes()) { - if (out_data_node->GetType() == SWITCH || out_data_node->GetType() == REFSWITCH) { - // 2.1 collect switch node can be fused with same cond_in_node - auto true_out_anchor = out_data_node->GetOutDataAnchor(kSwitchTrueOutIdx); - auto false_out_anchor = out_data_node->GetOutDataAnchor(kSwitchFalseOutIdx); - int branch_idx = true_out_anchor == nullptr ? 
kSwitchFalseOutIdx : kSwitchTrueOutIdx; - if (out_data_node->GetOutDataAnchor(branch_idx)->GetPeerInDataNodesSize() > 1) { - GELOGI("Current switch node %s has more than one output, need go to switch split first.", - out_data_node->GetName().c_str()); - continue; - } - string fusion_road_id; - fusion_road_id = GetFusionRoadId(std::to_string(branch_idx), out_data_node); - GELOGI("Switch node %s out idx %d, group_id is %s.", out_data_node->GetName().c_str(), branch_idx, - fusion_road_id.c_str()); - auto iter = switch_group_map_.find(fusion_road_id); - if (iter == switch_group_map_.end()) { - switch_group_map_.emplace(std::make_pair(fusion_road_id, std::set{out_data_node})); - } else { - // to avoid one cond node is also as data node - if (iter->second.count(out_data_node) == 0) { - iter->second.emplace(out_data_node); - } - } - } - } - // 3. fuse switch from different group - auto ret = FuseSwitchGroup(); - if (ret != SUCCESS) { - GELOGE(FAILED, "Fuse switch nodes with same final output to one failed."); - return ret; - } - return SUCCESS; -} -/* - * var1 ALLREDUCE/Cast var3 var1 var2 var3 ALLREDUCE/Cast - * \ / \ \ / \ | / \ / \ - * switch1 switch2 switch3 ======> AdamApplyOne / \--->switch1 - * \ | / \ / | - * AdamApplyOne / mul <--- identity - * \ / - * mul - */ -Status SwitchFusionPass::FuseSwitchGroup() { - for (auto &key_2_switch_group : switch_group_map_) { - if (key_2_switch_group.second.size() == 1) { - break; - } - // 1.Insert Identity node - NodePtr remain_switch = *key_2_switch_group.second.begin(); - auto switch_out_anchor_idx = GetSwitchOutDataIdx(key_2_switch_group.first); - auto identity_node = InsertIdentityNode(remain_switch, switch_out_anchor_idx); - if (identity_node == nullptr) { - GELOGE(INTERNAL_ERROR, "Create Identity op %s fail.", identity_node->GetName().c_str()); - return FAILED; - } - // 2. Remove all switch nodes between data anchors. 
- string hccl_group_id; - for (const auto &switch_node : key_2_switch_group.second) { - GELOGI("Get corresponding SWITCH node is %s.Out data anchor idx is %d.", switch_node->GetName().c_str(), - switch_out_anchor_idx); - // get hccl group id for remain switch - if (AttrUtils::GetStr(switch_node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id)) { - GELOGI("Get hccl group id %s of switch node %s.", hccl_group_id.c_str(), switch_node->GetName().c_str()); - } - auto switch_peer_in_data_anchor = - switch_node->GetOutDataAnchor(switch_out_anchor_idx)->GetPeerInDataAnchors().at(0); - GE_RETURN_WITH_LOG_IF_ERROR(RemoveSwitchBetweenTwoNode(switch_out_anchor_idx, switch_node)); - GE_RETURN_WITH_LOG_IF_ERROR(GraphUtils::AddEdge(identity_node->GetOutControlAnchor(), - switch_peer_in_data_anchor->GetOwnerNode()->GetInControlAnchor()), - "Link control edge from identity %s to out node %s.", - identity_node->GetName().c_str(), - switch_peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); - } - GELOGI("Start fusion switch nodes. 
Switch_nodes_set size is %d", key_2_switch_group.second.size()); - // 3.Fuse all switch to one, first is remain_switch - GE_RETURN_WITH_LOG_IF_ERROR(FuseSwitchNodesToOne(remain_switch, key_2_switch_group.second)); - if (!hccl_group_id.empty()) { - AttrUtils::SetStr(remain_switch->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id); - GELOGI("Set attr ATTR_NAME_HCCL_FUSED_GROUP for Stream node %s, value is %s.", remain_switch->GetName().c_str(), - hccl_group_id.c_str()); - } - // Link switch to identity - GraphUtils::AddEdge(remain_switch->GetOutDataAnchor(switch_out_anchor_idx), identity_node->GetInDataAnchor(0)); - } - return SUCCESS; -} -/* - * var1---- - * cond---- - * var2---- - */ -Status SwitchFusionPass::RemoveSwitchBetweenTwoNode(const int switch_out_anchor_idx, const NodePtr &switch_node) { - auto switch_in_data_anchor = switch_node->GetInDataAnchor(kSwitchDataInputIdx); - auto switch_in_cond_anchor = switch_node->GetInDataAnchor(kSwitchCondInputIdx); - // here we assume after switch split, one switch node only has one data output,so just get first is ok. 
- auto switch_peer_in_data_anchor = switch_node->GetOutDataAnchor(switch_out_anchor_idx)->GetPeerInDataAnchors().at(0); - // 2.1 unlink all data edge from switch to out_node - GE_RETURN_WITH_LOG_IF_ERROR( - GraphUtils::RemoveEdge(switch_node->GetOutDataAnchor(switch_out_anchor_idx), switch_peer_in_data_anchor), - "Remove edge from switch %s to out node %s.", switch_node->GetName().c_str(), - switch_peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); - // 2.2 replace data edge from switch_data_in_node to switch_data_out_node - if (switch_in_data_anchor->GetPeerOutAnchor() == nullptr) { - GELOGI("Switch %s in data peer out anchor is null.", switch_node->GetName().c_str()); - return FAILED; - } - auto switch_in_node = switch_in_data_anchor->GetPeerOutAnchor()->GetOwnerNode(); - GELOGI("Switch %s in data node is %s.", switch_node->GetName().c_str(), switch_in_node->GetName().c_str()); - GE_RETURN_WITH_LOG_IF_ERROR(GraphUtils::ReplaceEdgeDst(switch_in_data_anchor->GetPeerOutAnchor(), - switch_in_data_anchor, switch_peer_in_data_anchor), - "ReplaceEdgeDst from switch_data_in_node %s to switch_out_node %s.", - switch_in_node->GetName().c_str(), - switch_peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); - // 2.3 link control edge from switch_data_in_node to switch - GE_RETURN_WITH_LOG_IF_ERROR( - GraphUtils::AddEdge(switch_in_node->GetOutControlAnchor(), switch_node->GetInControlAnchor()), - "Link control edge from switch_data_in_node %s to switch node %s failed.", switch_in_node->GetName().c_str(), - switch_node->GetName().c_str()); - return SUCCESS; -} - -Status SwitchFusionPass::FuseSwitchNodesToOne(NodePtr &remain_switch, const std::set switch_nodes_set) { - auto iter = ++switch_nodes_set.begin(); - while (iter != switch_nodes_set.end()) { - GE_RETURN_WITH_LOG_IF_ERROR(GraphUtils::CopyInCtrlEdges(*iter, remain_switch), - "Copy in control edge from %s to %s failed.", (*iter)->GetName().c_str(), - remain_switch->GetName().c_str()); - 
GE_RETURN_WITH_LOG_IF_ERROR(NodeUtils::MoveOutputEdges(*iter, remain_switch), - "Move output edges from %s to %s failed.", (*iter)->GetName().c_str(), - remain_switch->GetName().c_str()); - if ((*iter)->GetOutDataNodesSize() == 0) { - auto ret = IsolateAndDeleteNode(const_cast(*iter), {}); - if (ret == SUCCESS) { - GELOGI("IsolateAndDeleteNode Switch node %s", (*iter)->GetName().c_str()); - } - } else { - GELOGI("Switch node %s has more than one out data nodes, keep it.", (*iter)->GetName().c_str()); - } - iter++; - } - // link data input for remain switch - auto cond_node = remain_switch->GetInDataAnchor(kSwitchCondInputIdx)->GetPeerOutAnchor()->GetOwnerNode(); - GELOGI("Get cond node %s of switch node %s.", cond_node->GetName().c_str(), remain_switch->GetName().c_str()); - GE_RETURN_WITH_LOG_IF_ERROR( - GraphUtils::AddEdge(cond_node->GetOutDataAnchor(0), remain_switch->GetInDataAnchor(kSwitchDataInputIdx)), - "Fail to add edge from cond_node %s to remain_switch %s.", cond_node->GetName().c_str(), - remain_switch->GetName().c_str()); - return SUCCESS; -} - -const string SwitchFusionPass::GetFusionRoadId(const string branch_id, const NodePtr &switch_node) { - std::deque queue; - queue.push_back(switch_node); - std::stringstream group_id; - group_id << branch_id; - - while (!queue.empty()) { - NodePtr node = queue.front(); - queue.pop_front(); - if (node->GetOutDataNodesSize() == 0) { - group_id << "-" << node->GetName(); - GELOGI("Switch node %s, group id is %s", switch_node->GetName().c_str(), group_id.str().c_str()); - return group_id.str(); - } - for (const auto &out_data_node : node->GetOutDataNodes()) { - if (out_data_node->GetType() == NETOUTPUT || out_data_node->GetType() == SWITCH || - out_data_node->GetType() == SWITCH) { - // if meet NETOUTPUT, it is the end of current ROAD - group_id << "-" << node->GetName(); - GELOGI("Switch node %s, group id is %s", switch_node->GetName().c_str(), group_id.str().c_str()); - return group_id.str(); - } - 
queue.emplace_back(out_data_node); - } - } - return group_id.str(); -} -NodePtr SwitchFusionPass::InsertIdentityNode(const NodePtr &remain_switch, const int out_data_anchor_idx) { - const std::string identity_name = remain_switch->GetOpDesc()->GetName() + "_" + IDENTITY; - ComputeGraphPtr graph = remain_switch->GetOwnerComputeGraph(); - auto data_desc = remain_switch->GetOpDesc()->GetOutputDesc(out_data_anchor_idx); - OpDescPtr op_desc = MakeShared(identity_name, IDENTITY); - if (op_desc == nullptr) { - GELOGE(FAILED, "Create Identity op %s: create op_desc fail.", identity_name.c_str()); - return nullptr; - } - if ((op_desc->AddInputDesc(data_desc) != GRAPH_SUCCESS) || (op_desc->AddOutputDesc(data_desc) != GRAPH_SUCCESS)) { - GELOGE(INTERNAL_ERROR, "Create Identity op %s: add input/output desc fail.", identity_name.c_str()); - return nullptr; - } - GELOGI("Create Identity op:%s.", identity_name.c_str()); - return graph->AddNode(op_desc); -} -} // namespace ge \ No newline at end of file diff --git a/src/ge/graph/passes/switch_split_pass.cc b/src/ge/graph/passes/switch_split_pass.cc deleted file mode 100644 index 07f59d20..00000000 --- a/src/ge/graph/passes/switch_split_pass.cc +++ /dev/null @@ -1,145 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "switch_split_pass.h" -#include -#include "common/ge/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "framework/common/util.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/node_utils.h" - -using namespace ge; -namespace { -const string output_false = "output_false"; -const string output_true = "output_true"; -string GetOutputDescName(const int idx) { return idx == 0 ? output_false : output_true; } -graphStatus CopyInDataEdges(const NodePtr &src_node, NodePtr &dst_node) { - if ((src_node == nullptr) || (dst_node == nullptr)) { - GELOGE(GRAPH_FAILED, "Parameter is nullptr"); - return GRAPH_PARAM_INVALID; - } - auto src_data_in_nodes = src_node->GetInDataNodes(); - if (src_data_in_nodes.empty()) { - return GRAPH_SUCCESS; - } - for (const auto &in_data_anchor : src_node->GetAllInDataAnchors()) { - auto input_desc = src_node->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx()); - auto ret = - GraphUtils::AddEdge(in_data_anchor->GetPeerOutAnchor(), dst_node->GetInDataAnchor(in_data_anchor->GetIdx())); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Failed to add data edge from %s to %s when copy in data edge from %s to %s", - in_data_anchor->GetPeerOutAnchor()->GetOwnerNode()->GetName().c_str(), dst_node->GetName().c_str(), - src_node->GetName().c_str(), dst_node->GetName().c_str()); - return ret; - } - } - return GRAPH_SUCCESS; -} -NodePtr CreateSwitchFromOld(const int index, const NodePtr &old_switch, const OutDataAnchorPtr &out_data_anchor) { - auto graph = old_switch->GetOwnerComputeGraph(); - // 1. create new switch op desc - string new_switch_name = old_switch->GetName() + "_" + std::to_string(index); - auto new_switch_opdesc = MakeShared(new_switch_name, old_switch->GetType()); - if (new_switch_opdesc == nullptr) { - GELOGE(OUT_OF_MEMORY, "Failed to insert switch node, name %s", new_switch_name.c_str()); - return nullptr; - } - // 2. 
add input_desc & output_desc for new switch - Status ret; - for (const auto &in_data_anchor : old_switch->GetAllInDataAnchors()) { - auto input_desc = old_switch->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx()); - ret = new_switch_opdesc->AddInputDesc(in_data_anchor->GetIdx(), input_desc); - if (ret != SUCCESS) { - GELOGE(FAILED, "Add Input desc failed for new switch %s.", new_switch_name.c_str()); - return nullptr; - } - } - auto output_desc = old_switch->GetOpDesc()->GetOutputDesc(out_data_anchor->GetIdx()); - // we got out_data_anchor, another out_data_anchor is (1-idx), because idx is 0 or 1. - auto ret1 = new_switch_opdesc->AddOutputDesc(GetOutputDescName(1 - out_data_anchor->GetIdx()), output_desc); - auto ret2 = new_switch_opdesc->AddOutputDesc(GetOutputDescName(out_data_anchor->GetIdx()), output_desc); - if (ret1 != SUCCESS || ret2 != SUCCESS) { - GELOGE(FAILED, "Add Output desc failed for new switch %s.", new_switch_name.c_str()); - return nullptr; - } - GELOGI("Insert new switch node %s.", new_switch_name.c_str()); - return graph->AddNode(new_switch_opdesc); -} -} // namespace -namespace ge { -Status SwitchSplitPass::Run(NodePtr &node) { - // To handle one out data anchor with multi peer input data anchor - GE_CHECK_NOTNULL(node); - OpDescPtr op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - if (op_desc->GetType() != SWITCH && op_desc->GetType() != REFSWITCH) { - return SUCCESS; - } - if (op_desc->GetName().find("apply_one_adam") == string::npos) { - // Currently for bert optimize, will fix later. - GELOGI("Current switch node name is %s, ignore it.", op_desc->GetName().c_str()); - return SUCCESS; - } - GELOGI("Switch split pass in. Current switch node name is %s", op_desc->GetName().c_str()); - int index = 0; - // 1. 
find all output - for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { - if (out_data_anchor->GetPeerInDataNodesSize() < 2) { - GELOGI("Switch node %s %d th out data anchor only has 1 peer_in_data_anchor.Ignore it.", node->GetName().c_str(), - out_data_anchor->GetIdx()); - continue; - } - for (const auto &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { - NodePtr new_switch = CreateSwitchFromOld(index, node, out_data_anchor); - if (new_switch == nullptr) { - GELOGW("Insert switch node failed."); - return FAILED; - } - // 1.3 copy int/out edge from old switch to new switch - auto ret1 = CopyInDataEdges(node, new_switch); - auto ret2 = GraphUtils::CopyInCtrlEdges(node, new_switch); - auto ret3 = GraphUtils::CopyOutCtrlEdges(node, new_switch); - if (ret1 != GRAPH_SUCCESS || ret2 != GRAPH_SUCCESS || ret3 != GRAPH_SUCCESS) { - GELOGE(FAILED, "Copy edge from %s to %s failed.", node->GetName().c_str(), new_switch->GetName().c_str()); - return FAILED; - } - if (out_data_anchor->Unlink(peer_in_anchor) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Unlink from old switch %s out data anchor %d to peer in anchor failed.", - node->GetName().c_str(), out_data_anchor->GetIdx()); - } - auto ret4 = GraphUtils::AddEdge(new_switch->GetOutDataAnchor(out_data_anchor->GetIdx()), peer_in_anchor); - if (ret4 != GRAPH_SUCCESS) { - GELOGE(FAILED, "Replace out data edge from old switch %s to new switch %s failed.", node->GetName().c_str(), - new_switch->GetName().c_str()); - return FAILED; - } - AddRePassNode(new_switch); - index++; - } - } - // 2.isolate switch node with no data output - if (node->GetOutDataNodesSize() == 0) { - auto ret = IsolateAndDeleteNode(node, {}); - if (ret != SUCCESS) { - GELOGE(FAILED, "IsolateAndDelete switch node %s.", node->GetName().c_str()); - return FAILED; - } - GELOGI("IsolateAndDelete switch node %s.", node->GetName().c_str()); - } - return SUCCESS; -} -} // namespace ge diff --git a/src/ge/graph/passes/switch_to_stream_switch_pass.cc 
b/src/ge/graph/passes/switch_to_stream_switch_pass.cc index ef8879dd..6c0d545d 100644 --- a/src/ge/graph/passes/switch_to_stream_switch_pass.cc +++ b/src/ge/graph/passes/switch_to_stream_switch_pass.cc @@ -48,7 +48,7 @@ Status SwitchToStreamSwitchPass::Run(ComputeGraphPtr graph) { } /// -/// @brief Clear Status, used for subgraph pass +/// @brief Clear Status /// @return /// Status SwitchToStreamSwitchPass::ClearStatus() { diff --git a/src/ge/graph/passes/transop_symmetry_elimination_pass.cc b/src/ge/graph/passes/transop_symmetry_elimination_pass.cc index 887079f8..9d0ac4d4 100644 --- a/src/ge/graph/passes/transop_symmetry_elimination_pass.cc +++ b/src/ge/graph/passes/transop_symmetry_elimination_pass.cc @@ -26,9 +26,7 @@ #include "types.h" namespace { -const int kTransOpOutIndex = 0; const std::set white_list_op{ge::TRANSPOSED, ge::RESHAPE, ge::REFORMAT, ge::CAST, ge::TRANSDATA}; -std::map precision_loss_transfer_map = {{ge::DT_FLOAT, ge::DT_BOOL}}; } // namespace namespace ge { Status TransOpSymmetryEliminationPass::Run(NodePtr &node) { @@ -98,7 +96,7 @@ bool TransOpSymmetryEliminationPass::CheckCanBeEliminated(const ge::NodePtr &src return false; } } - return CheckPrecisionLoss(src_node) && DescAreSymmetry(src_node, dst_node); + return TransOpUtil::CheckPrecisionLoss(src_node) && DescAreSymmetry(src_node, dst_node); } bool TransOpSymmetryEliminationPass::DescAreSymmetry(const NodePtr &src_node, const NodePtr &dst_node) { @@ -135,22 +133,6 @@ bool TransOpSymmetryEliminationPass::DescAreSymmetry(const NodePtr &src_node, co return is_symmetry; } -bool TransOpSymmetryEliminationPass::CheckPrecisionLoss(const ge::NodePtr &src_node) { - auto idx = TransOpUtil::GetTransOpDataIndex(src_node); - auto input_desc = src_node->GetOpDesc()->GetInputDesc(idx); - auto output_desc = src_node->GetOpDesc()->GetOutputDesc(kTransOpOutIndex); - auto src_dtype = input_desc.GetDataType(); - auto dst_dtype = output_desc.GetDataType(); - auto iter = 
precision_loss_transfer_map.find(src_dtype); - if (iter != precision_loss_transfer_map.end() && iter->second == dst_dtype) { - GELOGW("Node %s transfer data type from %s to %s ,it will cause precision loss. ignore pass.", - src_node->GetName().c_str(), TypeUtils::DataTypeToSerialString(src_dtype).c_str(), - TypeUtils::DataTypeToSerialString(dst_dtype).c_str()); - return false; - } - return true; -} - int TransOpSymmetryEliminationPass::GetUnknownDimsNum(const GeTensorDesc &node_desc) { // // unknown_dims_num != 0 , is dynamic shape @@ -246,22 +228,30 @@ Status TransOpSymmetryEliminationPass::EliminateTransOp(NodePtr &src_node, const } GELOGI("Trans op symmetry eliminate successfully. Node %s has been removed.", dst_node->GetName().c_str()); // 6.If T1 has no data out, isolate and deleted it. - if (src_node->GetOutDataNodesSize() == 0) { + ret = RemoveTransOpWithoutOutput(pre_normal_node, src_node); + if (ret != GRAPH_SUCCESS) { + GELOGE(ret, "Isolate removed node: %s, type: %s failed", src_node->GetName().c_str(), src_node->GetType().c_str()); + return ret; + } + return SUCCESS; +} +Status TransOpSymmetryEliminationPass::RemoveTransOpWithoutOutput(NodePtr &pre_node, NodePtr &trans_node) { + if (trans_node->GetOutDataNodesSize() == 0) { // 6.1 Copy out control to pre normal node - ret = GraphUtils::CopyOutCtrlEdges(src_node, pre_normal_node); + Status ret = GraphUtils::CopyOutCtrlEdges(trans_node, pre_node); if (ret != GRAPH_SUCCESS) { - GELOGE(FAILED, "Copy control edge from %s to %s failed.", src_node->GetName().c_str(), - dst_node->GetName().c_str()); + GELOGE(FAILED, "Copy control edge from %s to %s failed.", trans_node->GetName().c_str(), + pre_node->GetName().c_str()); return ret; } // 6.2 Isolate and delete T1 - ret = IsolateAndDeleteNode(src_node, {}); + ret = IsolateAndDeleteNode(trans_node, {}); if (ret != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Isolate removed node: %s, type: %s failed", src_node->GetName().c_str(), - src_node->GetType().c_str()); + 
GELOGE(INTERNAL_ERROR, "Isolate removed node: %s, type: %s failed", trans_node->GetName().c_str(), + trans_node->GetType().c_str()); return ret; } - GELOGI("Trans op symmetry eliminate successfully. Node %s has been removed.", src_node->GetName().c_str()); + GELOGI("Trans op symmetry eliminate successfully. Node %s has been removed.", trans_node->GetName().c_str()); } return SUCCESS; } diff --git a/src/ge/graph/passes/transop_symmetry_elimination_pass.h b/src/ge/graph/passes/transop_symmetry_elimination_pass.h index 7f7409b7..2c89ed48 100644 --- a/src/ge/graph/passes/transop_symmetry_elimination_pass.h +++ b/src/ge/graph/passes/transop_symmetry_elimination_pass.h @@ -58,15 +58,6 @@ class TransOpSymmetryEliminationPass : public BaseNodePass { /// static bool JudgeTransposeDBack2Raw(const NodePtr &src_node, const NodePtr &dst_node); - /// - /// two transform nodes can not be offset if there is precision loss, like FP32->BOOL BOOL->FP32. - /// keep this pair of transform nodes if it has precision loss. 
- /// @param src_node: the front node - /// @param dst_node: the back node - /// @return True or False, whether can be offset or not - /// - static bool CheckPrecisionLoss(const NodePtr &src_node); - /// /// two transform nodes can be offset like A->T1->T2->B /// 1.unlink T1->T2 @@ -83,6 +74,8 @@ class TransOpSymmetryEliminationPass : public BaseNodePass { /// Status EliminateTransOp(NodePtr &src_node, const OutDataAnchorPtr &src_out_anchor, NodePtr &dst_node, const InDataAnchorPtr &dst_in_anchor); + + Status RemoveTransOpWithoutOutput(NodePtr &pre_node, NodePtr &trans_node); }; } // namespace ge diff --git a/src/ge/graph/passes/transop_without_reshape_fusion_pass.cc b/src/ge/graph/passes/transop_without_reshape_fusion_pass.cc index 1d97d9a1..3080e886 100644 --- a/src/ge/graph/passes/transop_without_reshape_fusion_pass.cc +++ b/src/ge/graph/passes/transop_without_reshape_fusion_pass.cc @@ -22,6 +22,7 @@ #include "common/ge/ge_util.h" #include "common/ge_inner_error_codes.h" #include "common/types.h" +#include "graph/common/transop_util.h" #include "graph/compute_graph.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_tensor.h" @@ -89,7 +90,7 @@ bool TransOpWithoutReshapeFusionPass::FormatContinuousCheck(const OutDataAnchorP } if (in_op->GetType() == CAST || out_op->GetType() == CAST) { - return true; + return TransOpUtil::CheckPrecisionLoss(in_node); } if (in_op_desc->GetFormat() == FORMAT_ND) { diff --git a/src/ge/graph/passes/unused_args_clean_pass.cc b/src/ge/graph/passes/unused_args_clean_pass.cc new file mode 100644 index 00000000..62094631 --- /dev/null +++ b/src/ge/graph/passes/unused_args_clean_pass.cc @@ -0,0 +1,206 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "unused_args_clean_pass.h" + +#include "graph/utils/node_utils.h" + +namespace ge { +Status UnusedArgsCleanPass::Run(ComputeGraphPtr graph) { + GE_CHECK_NOTNULL(graph); + if (graph->GetParentGraph() != nullptr) { + GELOGD("Subgraph %s skip the UnusedArgsCleanPass", graph->GetName().c_str()); + return SUCCESS; + } + GELOGD("Begin to run Unused args clean on graph: %s", graph->GetName().c_str()); + + for (const auto &node : graph->GetDirectNode()) { + if (node->GetType() != CASE) { + continue; + } + + const auto &func_desc = node->GetOpDesc(); + map> graph_nodes; + if (ClassifyDataNodes(graph, func_desc, graph_nodes) != SUCCESS) { + return FAILED; + } + + // {subgraph0, {{0, Data}, {1, Data}, {2, Data}, {3, Data}, ..., {n, Data}}} + // {subgraph1, {{0, Data}, {1, Data}, {2, Data}, {3, Data}, ..., {n, Data}}} + // {subgraph2, {{0, Data}, {1, Data}, {2, Data}, {3, Data}, ..., {n, Data}}} + uint32_t unused_args_num = 0; + uint32_t inputs_args_num = func_desc->GetInputsSize(); + for (size_t i = 1; i < inputs_args_num; ++i) { + if (UnusedInputTensor(graph_nodes, node, i)) { + unused_args_num++; + } else { + (void)UpdateInputTensor(graph_nodes, node, i, unused_args_num); + } + } + + (void)NodeUtils::RemoveInputAnchor(node, inputs_args_num - unused_args_num); + } + + return SUCCESS; +} + +/// +/// @ingroup ge +/// @brief Create nodes for root graph. +/// @param [in] graph_nodes: Data groups of subgraph. +/// @param [in] func_node: functional Node of Case. +/// @param [in] parent_index: parent index for check. 
+/// @return true: unused / false: used +/// +bool UnusedArgsCleanPass::UnusedInputTensor(const map> &graph_nodes, + const NodePtr &func_node, uint32_t parent_index) { + for (const auto &item : graph_nodes) { + const auto &nodes = item.second; + const auto it = nodes.find(parent_index); + if (it == nodes.end()) { // not used. + continue; + } + + const auto &data = it->second; + for (const auto out_anchor : data->GetAllOutAnchors()) { + for (const auto in_anchor : out_anchor->GetPeerAnchors()) { + if (in_anchor == nullptr) { + continue; + } + + return false; + } + } + } + + return RemoveInputTensor(graph_nodes, func_node, parent_index) == SUCCESS; +} + +/// +/// @ingroup ge +/// @brief Get all Data nodes for all subgraph. +/// @param [in] graph: Root compute graph. +/// @param [in] func_desc: functional OpDesc of Case. +/// @param [out] graph_nodes: Data groups of subgraph. +/// @return 0: SUCCESS / others: FAILED +/// +Status UnusedArgsCleanPass::ClassifyDataNodes(const ComputeGraphPtr &graph, const OpDescPtr &func_desc, + map> &graph_nodes) { + for (const auto &name : func_desc->GetSubgraphInstanceNames()) { + const auto &subgraph = graph->GetSubgraph(name); + if (subgraph == nullptr) { + GELOGE(GE_GRAPH_EMPTY_SUBGRAPH, "Subgraph not found, name: %s", name.c_str()); + return GE_GRAPH_EMPTY_SUBGRAPH; + } + + auto &data_nodes = graph_nodes[subgraph]; + for (auto &data : subgraph->GetDirectNode()) { + if (data->GetType() != DATA) { + continue; + } + + uint32_t parent_index = 0; + if (!AttrUtils::GetInt(data->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + GELOGE(FAILED, "Parent index not found, name: %s", data->GetName().c_str()); + return FAILED; + } + + data_nodes[parent_index] = data; + GELOGD("%s, Parent index: %u, Data: %s", subgraph->GetName().c_str(), parent_index, data->GetName().c_str()); + } + } + + return SUCCESS; +} + +/// +/// @ingroup ge +/// @brief Update Case input Tensor. +/// @param [in] graph_nodes: Data groups of subgraph. 
+/// @param [in] func_node: functional Node of Case. +/// @param [in] parent_index: parent index for update. +/// @param [in] unused_num: unused args num. +/// @return 0: SUCCESS / others: FAILED +/// +Status UnusedArgsCleanPass::UpdateInputTensor(const map> &graph_nodes, + const NodePtr &func_node, uint32_t parent_index, uint32_t unused_num) { + if (unused_num == 0) { + return SUCCESS; + } + + uint32_t update_index = parent_index - unused_num; + for (const auto &item : graph_nodes) { + const auto &nodes = item.second; + const auto it = nodes.find(parent_index); + if (it == nodes.end()) { // not used. + continue; + } + const auto data = it->second; + + if (!AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, update_index)) { + GELOGE(FAILED, "Set parent index failed, name: %s", data->GetName().c_str()); + return FAILED; + } + } + + const auto &new_anchor = func_node->GetInDataAnchor(update_index); + const auto &old_anchor = func_node->GetInDataAnchor(parent_index); + const auto &out_anchor = old_anchor->GetPeerOutAnchor(); + const auto &out_node = out_anchor->GetOwnerNode(); + + GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(out_anchor, new_anchor), "Add edge failed"); + GELOGI("Add edge success, func node: %s, node: %s, parent index: %u, update index: %u", func_node->GetName().c_str(), + out_node->GetName().c_str(), parent_index, update_index); + + GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(out_anchor, old_anchor), "Remove edge failed"); + GELOGI("Remove edge success, func node: %s, node: %s", func_node->GetName().c_str(), out_node->GetName().c_str()); + + return SUCCESS; +} + +/// +/// @ingroup ge +/// @brief Remove Case input Tensor. +/// @param [in] graph_nodes: Data groups of subgraph. +/// @param [in] func_node: functional Node of Case. +/// @param [in] parent_index: parent index for remove. 
+/// @return 0: SUCCESS / others: FAILED +/// +Status UnusedArgsCleanPass::RemoveInputTensor(const map> &graph_nodes, + const NodePtr &func_node, uint32_t parent_index) { + for (const auto &item : graph_nodes) { + const auto &graph = item.first; + const auto &nodes = item.second; + const auto it = nodes.find(parent_index); + if (it == nodes.end()) { // not used. + continue; + } + + const auto &data = it->second; + GE_CHK_GRAPH_STATUS_RET(graph->RemoveNode(data), "Remove node failed: %s", data->GetName().c_str()); + GELOGI("Remove Node: %s %s", graph->GetName().c_str(), data->GetName().c_str()); + } + + const auto &old_anchor = func_node->GetInDataAnchor(parent_index); + const auto &out_anchor = old_anchor->GetPeerOutAnchor(); + const auto &out_node = out_anchor->GetOwnerNode(); + + GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(out_anchor, old_anchor), "Remove edge failed"); + GELOGI("Remove edge: %s %s", out_node->GetName().c_str(), func_node->GetName().c_str()); + + return SUCCESS; +} +} // namespace ge \ No newline at end of file diff --git a/src/ge/graph/passes/unused_args_clean_pass.h b/src/ge/graph/passes/unused_args_clean_pass.h new file mode 100644 index 00000000..59f7b647 --- /dev/null +++ b/src/ge/graph/passes/unused_args_clean_pass.h @@ -0,0 +1,83 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_COMMON_CASE_ARGS_CLEAN_H_ +#define GE_COMMON_CASE_ARGS_CLEAN_H_ + +#include "graph/types.h" +#include "inc/graph_pass.h" + +#include +#include +#include +#include + +using std::map; +using std::set; + +namespace ge { +class UnusedArgsCleanPass : public GraphPass { + public: + Status Run(ComputeGraphPtr graph) override; + + private: + /// + /// @ingroup ge + /// @brief Create nodes for root graph. + /// @param [in] graph_nodes: Data groups of subgraph. + /// @param [in] func_node: functional Node of Case. + /// @param [in] parent_index: parent index for check. + /// @return true: unused / false: used + /// + bool UnusedInputTensor(const map> &graph_nodes, const NodePtr &func_node, + uint32_t parent_index); + + /// + /// @ingroup ge + /// @brief Get all Data nodes for all subgraph. + /// @param [in] graph: Root compute graph. + /// @param [in] func_desc: functional OpDesc of Case. + /// @param [out] graph_nodes: Data groups of subgraph. + /// @return 0: SUCCESS / others: FAILED + /// + Status ClassifyDataNodes(const ComputeGraphPtr &graph, const OpDescPtr &func_desc, + map> &graph_nodes); + + /// + /// @ingroup ge + /// @brief Remove Case input Tensor. + /// @param [in] graph_nodes: Data groups of subgraph. + /// @param [in] func_node: functional Node of Case. + /// @param [in] parent_index: parent index for remove. + /// @return 0: SUCCESS / others: FAILED + /// + Status RemoveInputTensor(const map> &graph_nodes, const NodePtr &func_node, + uint32_t parent_index); + + /// + /// @ingroup ge + /// @brief Update Case input Tensor. + /// @param [in] graph_nodes: Data groups of subgraph. + /// @param [in] func_node: functional Node of Case. + /// @param [in] parent_index: parent index for update. + /// @param [in] unused_num: unused args num. 
+ /// @return 0: SUCCESS / others: FAILED + /// + Status UpdateInputTensor(const map> &graph_nodes, const NodePtr &func_node, + uint32_t parent_index, uint32_t unused_num); +}; +} // namespace ge +#endif // GE_COMMON_CASE_ARGS_CLEAN_H_ diff --git a/src/ge/graph/passes/variable_prepare_op_pass.cc b/src/ge/graph/passes/variable_prepare_op_pass.cc index d93e1003..f0e11735 100644 --- a/src/ge/graph/passes/variable_prepare_op_pass.cc +++ b/src/ge/graph/passes/variable_prepare_op_pass.cc @@ -18,7 +18,6 @@ #include #include #include -#include #include "common/ge/ge_util.h" #include "external/graph/graph.h" #include "framework/common/debug/ge_log.h" @@ -28,8 +27,8 @@ #include "graph/utils/tensor_utils.h" namespace ge { -std::map> VariablePrepareOpPass::ref_node_without_prototype_map_{ - {REFSWITCH, {{0, 0}, {0, 1}}}}; +std::map>> VariablePrepareOpPass::ref_node_without_prototype_map_ = { + {REFSWITCH, {{0, {0, 1}}}}}; Status VariablePrepareOpPass::Run(ComputeGraphPtr graph) { GE_CHECK_NOTNULL(graph); @@ -55,14 +54,6 @@ Status VariablePrepareOpPass::Run(ComputeGraphPtr graph) { } } } - - for (auto iter = ref_input_output_map_.begin(); iter != ref_input_output_map_.end(); ++iter) { - GELOGI("ref type:[ %s ]", iter->first.c_str()); - auto index_map = iter->second; - for (auto index_iter = index_map.begin(); index_iter != index_map.end(); ++index_iter) { - GELOGI("{ %d : %d }", index_iter->first, index_iter->second); - } - } return SUCCESS; } @@ -74,54 +65,53 @@ Status VariablePrepareOpPass::DealVariableNode(NodePtr &var_node) { InDataAnchorPtr dst_in_data_anchor = dst_node_and_inanchor.second; GE_CHECK_NOTNULL(dst_in_data_anchor); auto input_index = dst_in_data_anchor->GetIdx(); - int out_index = GetWritableNodeOutIndex(dst_node, input_index); - if (out_index >= 0) { - Status ret = DealWritableNode(dst_node, input_index, var_node); - if (ret != SUCCESS) { - GELOGE(FAILED, "Deal writable node[%s] failed, input index: %d, var: %s.", dst_node->GetName().c_str(), - input_index, 
var_node->GetName().c_str()); - return FAILED; + vector ref_output_indexes; + GetWritableNodeOutIndex(dst_node, input_index, ref_output_indexes); + if (!ref_output_indexes.empty()) { + for (auto output_index : ref_output_indexes) { + Status ret = DealWritableNode(dst_node, input_index, output_index, var_node); + if (ret != SUCCESS) { + GELOGE(FAILED, "Deal writable node[%s] failed, input index: %d, var: %s.", dst_node->GetName().c_str(), + input_index, var_node->GetName().c_str()); + return FAILED; + } } } } return SUCCESS; } -Status VariablePrepareOpPass::DealWritableNode(const ge::NodePtr &writable_node, int input_index, +Status VariablePrepareOpPass::DealWritableNode(const ge::NodePtr &writable_node, int input_index, int output_index, const ge::NodePtr &var_node) { // Find the last ref node: // If the ref input has corresponding output, add variable ref after it. // If the ref input has no corresponding output, insert RefIdentity and variable ref before it. // If ref node with control output was found while finding the last ref node, add variable ref after it. 
- std::stack> nodes_to_check; - nodes_to_check.push({writable_node, input_index}); + std::stack>> nodes_to_check; + nodes_to_check.push({writable_node, {input_index, output_index}}); while (!nodes_to_check.empty()) { auto node_index = nodes_to_check.top(); nodes_to_check.pop(); auto cur_node = node_index.first; - int cur_input_index = node_index.second; + int cur_input_index = node_index.second.first; + int cur_output_index = node_index.second.second; // Collect ref node after cur node const auto nodes_size = nodes_to_check.size(); // Add peer ref output node of current node to stack - CHECK_FALSE_EXEC(GetPeerNodeOfRefInput(cur_node, cur_input_index, nodes_to_check) == SUCCESS, - GELOGE(FAILED, "GetPeerNodeOfRefInput for node[%s] failed.", cur_node->GetName().c_str()); - return FAILED); - auto output_index = GetWritableNodeOutIndex(cur_node, cur_input_index); - CHECK_FALSE_EXEC(output_index >= 0, - GELOGE(FAILED, "Get writable node[%s] ref input[%d]'s corresponding out index failed: %d.", - cur_node->GetName().c_str(), cur_input_index, output_index); + CHECK_FALSE_EXEC(GetPeerNodeOfRefOutput(cur_node, cur_output_index, nodes_to_check) == SUCCESS, + GELOGE(FAILED, "GetPeerNodeOfRefOutput for node[%s] failed.", cur_node->GetName().c_str()); return FAILED); if (nodes_size == nodes_to_check.size()) { const auto &op_desc = cur_node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); - // No need to add variable_ref for frameworkop + // No need to add variable_ref for framework op if (op_desc->GetType() == FRAMEWORKOP) { GELOGD("No need to add variable_ref for frameworkop"); continue; } - if (static_cast(output_index) < op_desc->GetOutputsSize()) { + if (static_cast(cur_output_index) < op_desc->GetOutputsSize()) { // Add variable ref node after ref output for final ref node - CHECK_FALSE_EXEC(AddVariableRef(cur_node, var_node, output_index) == SUCCESS, + CHECK_FALSE_EXEC(AddVariableRef(cur_node, var_node, cur_output_index) == SUCCESS, GELOGE(FAILED, "Add variable ref failed"); 
return FAILED); } else { @@ -134,7 +124,7 @@ Status VariablePrepareOpPass::DealWritableNode(const ge::NodePtr &writable_node, } if (HasControlOut(cur_node)) { // Add variable ref node after ref output for ref node has control output. - CHECK_FALSE_EXEC(AddVariableRef(cur_node, var_node, output_index) == SUCCESS, + CHECK_FALSE_EXEC(AddVariableRef(cur_node, var_node, cur_output_index) == SUCCESS, GELOGE(FAILED, "Add variable ref failed"); return FAILED); } @@ -142,11 +132,10 @@ Status VariablePrepareOpPass::DealWritableNode(const ge::NodePtr &writable_node, return SUCCESS; } -Status VariablePrepareOpPass::GetPeerNodeOfRefInput(const ge::NodePtr &node, int input_index, - std::stack> &nodes) { - auto output_index = GetWritableNodeOutIndex(node, input_index); - if (output_index == -1) { - GELOGE(PARAM_INVALID, "Node[%s] is not a ref node.", node->GetName().c_str()); +Status VariablePrepareOpPass::GetPeerNodeOfRefOutput(const ge::NodePtr &node, int output_index, + std::stack>> &nodes) { + if (output_index < 0) { + GELOGE(PARAM_INVALID, "Invalid ref output index: %s-%d.", node->GetName().c_str(), output_index); return PARAM_INVALID; } const auto &op_desc = node->GetOpDesc(); @@ -166,8 +155,10 @@ Status VariablePrepareOpPass::GetPeerNodeOfRefInput(const ge::NodePtr &node, int continue; } const int peer_in_index = peer_in_anchor->GetIdx(); - if (GetWritableNodeOutIndex(peer_node, peer_in_index) != -1) { - nodes.push({peer_node, peer_in_index}); + vector ref_output_indexes; + GetWritableNodeOutIndex(peer_node, peer_in_index, ref_output_indexes); + for (auto ref_output_index : ref_output_indexes) { + nodes.push({peer_node, {peer_in_index, ref_output_index}}); } } return SUCCESS; @@ -353,9 +344,10 @@ ge::NodePtr VariablePrepareOpPass::CreateVariableRef(const std::string &variable return variable_ref_node; } -int VariablePrepareOpPass::GetWritableNodeOutIndex(const NodePtr &node, int input_index) { +void VariablePrepareOpPass::GetWritableNodeOutIndex(const NodePtr &node, int 
input_index, + std::vector &output_indexes) { if (node == nullptr) { - return -1; + return; } GELOGD("get writable node and input index %s:%d", node->GetName().c_str(), input_index); auto node_type = node->GetType(); @@ -363,9 +355,11 @@ int VariablePrepareOpPass::GetWritableNodeOutIndex(const NodePtr &node, int inpu std::string original_type; GE_IF_BOOL_EXEC(GetOriginalType(node, original_type) != SUCCESS, GELOGW("Get node original type fail")); GELOGD("find frameworkop: [%s], original type is %s", node->GetName().c_str(), original_type.c_str()); - return FindRefOutIndex(original_type, input_index, ref_node_without_prototype_map_); + FindRefOutIndex(original_type, input_index, ref_node_without_prototype_map_, output_indexes); + return; } - return FindRefOutIndex(node_type, input_index, ref_input_output_map_); + FindRefOutIndex(node_type, input_index, ref_input_output_map_, output_indexes); + return; } void VariablePrepareOpPass::GenerateRefTypeAndInputOutputMap(const NodePtr &node) { @@ -378,29 +372,32 @@ void VariablePrepareOpPass::GenerateRefTypeAndInputOutputMap(const NodePtr &node // Record the index of output with the same name as input, thinking of them as a pair of ref input and output. const int out_index = op_desc->GetOutputIndexByName(name_index.first); if (out_index != -1) { - ref_input_output_map_[node->GetType()][name_index.second] = out_index; + ref_input_output_map_[node->GetType()][name_index.second] = {out_index}; continue; } // Record the ref input without corresponding output. 
const auto &input_desc = op_desc->GetInputDesc(name_index.second); if (!input_desc.GetRefPortIndex().empty()) { - ref_input_output_map_[node->GetType()][name_index.second] = static_cast(op_desc->GetOutputsSize()); + ref_input_output_map_[node->GetType()][name_index.second] = {static_cast(op_desc->GetOutputsSize())}; } } } -int VariablePrepareOpPass::FindRefOutIndex(const std::string &node_type, int input_index, - const std::map> &ref_map) { +void VariablePrepareOpPass::FindRefOutIndex(const std::string &node_type, int input_index, + const std::map>> &ref_map, + std::vector &output_indexes) { auto node_iter = ref_map.find(node_type); if (node_iter == ref_map.end()) { - return -1; + return; } auto index_iter = node_iter->second.find(input_index); if (index_iter == node_iter->second.end()) { - return -1; + return; + } + for (const auto &out_index : index_iter->second) { + output_indexes.emplace_back(out_index); } - return index_iter->second; } Status VariablePrepareOpPass::CheckStreamLabel(const ge::NodePtr &var_ref_node, diff --git a/src/ge/graph/passes/variable_prepare_op_pass.h b/src/ge/graph/passes/variable_prepare_op_pass.h index f024a464..563a9be5 100644 --- a/src/ge/graph/passes/variable_prepare_op_pass.h +++ b/src/ge/graph/passes/variable_prepare_op_pass.h @@ -31,22 +31,25 @@ class VariablePrepareOpPass : public GraphPass { private: Status DealVariableNode(ge::NodePtr &node); - Status DealWritableNode(const ge::NodePtr &writable_node, int input_index, const ge::NodePtr &var_node); - Status GetPeerNodeOfRefInput(const ge::NodePtr &node, int input_index, std::stack> &nodes); + Status DealWritableNode(const ge::NodePtr &writable_node, int input_index, int output_index, + const ge::NodePtr &var_node); + Status GetPeerNodeOfRefOutput(const ge::NodePtr &node, int output_index, + std::stack>> &nodes); Status AddVariableRef(ge::NodePtr &node, const ge::NodePtr &var_node, int index); Status InsertVariableRef(ge::NodePtr &node, int in_index, const ge::NodePtr 
&var_node); Status AddControlEdge(const ge::NodePtr &node, const ge::NodePtr &variable_ref_node); NodePtr CreateVariableRef(const std::string &variable_ref_name, const ge::NodePtr &var_node); NodePtr CreateRefIdentity(const std::string &ref_identity_name, const ge::NodePtr &node, uint32_t input_index); - int GetWritableNodeOutIndex(const NodePtr &node, int input_index); + void GetWritableNodeOutIndex(const NodePtr &node, int input_index, std::vector &output_indexes); void GenerateRefTypeAndInputOutputMap(const NodePtr &node); - int FindRefOutIndex(const std::string &node_type, int input_index, - const std::map> &ref_map); + void FindRefOutIndex(const std::string &node_type, int input_index, + const std::map>> &ref_map, + std::vector &output_indexes); Status CheckStreamLabel(const ge::NodePtr &var_ref_node, const ge::NodePtr &final_writable_node); bool HasControlOut(const ge::NodePtr &node); - std::map> ref_input_output_map_; - static std::map> ref_node_without_prototype_map_; + std::map>> ref_input_output_map_; + static std::map>> ref_node_without_prototype_map_; }; } // namespace ge diff --git a/src/ge/graph/passes/variable_ref_delete_op_pass.cc b/src/ge/graph/passes/variable_ref_delete_op_pass.cc index 3487df47..90cfd747 100644 --- a/src/ge/graph/passes/variable_ref_delete_op_pass.cc +++ b/src/ge/graph/passes/variable_ref_delete_op_pass.cc @@ -72,6 +72,10 @@ Status VariableRefDeleteOpPass::DealVariableRef(ge::ComputeGraphPtr &graph, ge:: if (is_set_str) { GELOGI("[%s-%d]: add attr [REF_VAR_SRC_VAR_NAME: %s ] ", peer_node->GetName().c_str(), index, ref_var_src_var_name.c_str()); + } else { + GELOGE(FAILED, "[%s-%d]: add attr [REF_VAR_SRC_VAR_NAME: %s ] failed", peer_node->GetName().c_str(), index, + ref_var_src_var_name.c_str()); + return FAILED; } // remove variable_ref if (GraphUtils::IsolateNode(variable_ref, {0}) != GRAPH_SUCCESS) { diff --git a/src/ge/graph/preprocess/graph_preprocess.cc b/src/ge/graph/preprocess/graph_preprocess.cc index 3d0f1514..4df22cfc 
100644 --- a/src/ge/graph/preprocess/graph_preprocess.cc +++ b/src/ge/graph/preprocess/graph_preprocess.cc @@ -42,6 +42,7 @@ #include "graph/passes/addn_pass.h" #include "graph/passes/aicpu_constant_folding_pass.h" #include "graph/passes/assert_pass.h" +#include "graph/passes/assign_pass.h" #include "graph/passes/base_pass.h" #include "graph/passes/common_subexpression_elimination_pass.h" #include "graph/passes/cond_pass.h" @@ -82,12 +83,10 @@ #include "graph/passes/subgraph_pass.h" #include "graph/passes/switch_data_edges_bypass.h" #include "graph/passes/switch_dead_branch_elimination.h" -#include "graph/passes/switch_fusion_pass.h" #include "graph/passes/switch_logic_remove_pass.h" #include "graph/passes/merge_to_stream_merge_pass.h" #include "graph/passes/switch_to_stream_switch_pass.h" #include "graph/passes/attach_stream_label_pass.h" -#include "graph/passes/switch_split_pass.h" #include "graph/passes/unused_const_pass.h" #include "graph/passes/unused_op_remove_pass.h" #include "graph/passes/var_is_initialized_op_pass.h" @@ -127,6 +126,9 @@ static std::map output_type_str_to_datatype = { const char *const kMbatchSwitchnName = "mbatch-switch-name"; +// the size of user defined output datatype or format string after split by ":". 
+const size_t kUserDefinedElementCount = 2; + OpDescPtr CreateTensorShape(const GeTensorDesc &data_tensor) { GeTensorPtr tensor = MakeShared(); if (tensor == nullptr) { @@ -587,16 +589,6 @@ Status CheckIfDynamicBatchScene(NodePtr &data_node, bool &is_dynamic_batch, Node return SUCCESS; } -bool CheckIfSetOutputType(std::string &output_type, ge::DataType &output_data_type) { - if (output_type_str_to_datatype.find(output_type) != output_type_str_to_datatype.end()) { - output_data_type = output_type_str_to_datatype[output_type]; - return true; - } else { - GELOGI("output_type [%s] is not set or set unexpected", output_type.c_str()); - return false; - } - return false; -} bool CheckOpType(const NodePtr &node, const std::string type) { if (node->GetType() == type) { return true; @@ -637,18 +629,18 @@ Status CheckIfNeedSetNdFormat(const NodePtr &node_ptr) { // In the dynamic shape process, transnode insertion by FE is advanced to the stage of whole // graph optimization, GE only sets the final data_type/format/shape information for variable, // data and netoutput, and no longer inserts the transnode. 
-Status ProcessInputFP16DynShape(NodePtr &node_ptr, bool &is_dynamic_batch, NodePtr &switchn_node) { +Status ProcessInputDtDynShape(NodePtr &node_ptr, bool &is_dynamic_batch, NodePtr &switchn_node, DataType &dt_set) { GE_CHECK_NOTNULL(node_ptr); auto op_desc = node_ptr->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); const GeTensorDescPtr &input = op_desc->MutableInputDesc(0); GE_CHECK_NOTNULL(input); ge::DataType src_dtype = input->GetDataType(); - if (src_dtype == DT_FLOAT16) { + if (src_dtype == dt_set) { GELOGI("The node name, %s dtype is fp16", node_ptr->GetName().c_str()); return SUCCESS; } - input->SetDataType(DT_FLOAT16); + input->SetDataType(dt_set); int64_t input_shape_size = 0; int64_t output_shape_size = 0; ge::graphStatus input_graph_status = ge::TensorUtils::GetTensorSizeInBytes(*input, input_shape_size); @@ -660,7 +652,7 @@ Status ProcessInputFP16DynShape(NodePtr &node_ptr, bool &is_dynamic_batch, NodeP ge::TensorUtils::SetSize(*input, input_shape_size); const GeTensorDescPtr &output = op_desc->MutableOutputDesc(0); GE_CHECK_NOTNULL(output); - output->SetDataType(DT_FLOAT16); + output->SetDataType(dt_set); ge::TensorUtils::SetSize(*output, output_shape_size); if (is_dynamic_batch) { GELOGI("The node [%s] dtype set fp16", switchn_node->GetName().c_str()); @@ -668,11 +660,11 @@ Status ProcessInputFP16DynShape(NodePtr &node_ptr, bool &is_dynamic_batch, NodeP GE_CHECK_NOTNULL(switchn_op_desc); auto switchn_input = switchn_op_desc->MutableInputDesc(0); GE_CHECK_NOTNULL(switchn_input); - switchn_input->SetDataType(DT_FLOAT16); + switchn_input->SetDataType(dt_set); for (uint32_t i = 0; i < switchn_node->GetAllOutDataAnchorsSize(); ++i) { const GeTensorDescPtr &switchn_output = switchn_op_desc->MutableOutputDesc(i); GE_CHECK_NOTNULL(switchn_output); - switchn_output->SetDataType(DT_FLOAT16); + switchn_output->SetDataType(dt_set); } } return SUCCESS; @@ -720,20 +712,11 @@ Status ProcessInputNC1HWC0DynShape(NodePtr &node_ptr, bool &is_dynamic_batch, No Status 
ProcessDataNodeDynShape(NodePtr &node_ptr) { auto op_desc = node_ptr->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); - bool set_fp16 = false; - if (!ge::AttrUtils::GetBool(node_ptr->GetOpDesc(), "input_fp16", set_fp16) || !set_fp16) { + string set_dt_str; + if (!ge::AttrUtils::GetStr(node_ptr->GetOpDesc(), ATTR_ATC_USER_DEFINE_DATATYPE, set_dt_str)) { return SUCCESS; } - for (auto const &next_node : node_ptr->GetOutNodes()) { - if (next_node->GetType() == AIPP) { - ErrorManager::GetInstance().ATCReportErrMessage("E10034", {"opname"}, {node_ptr->GetName()}); - GELOGE(INTERNAL_ERROR, - "This input op [%s] is linked to aipp, can not be set to fp16, " - "please check your atc parameter --insert_op_conf, --input_fp16_nodes.", - node_ptr->GetName().c_str()); - return FAILED; - } - } + DataType dt_set = TypeUtils::SerialStringToDataType(set_dt_str); GELOGI("input_fp16 is found, the node name is %s.", node_ptr->GetName().c_str()); bool is_dynamic_batch = false; NodePtr switchn_node = nullptr; @@ -741,14 +724,14 @@ Status ProcessDataNodeDynShape(NodePtr &node_ptr) { GELOGE(INTERNAL_ERROR, "CheckIfDynamicBatchScene failed"); return FAILED; } - if (ProcessInputFP16DynShape(node_ptr, is_dynamic_batch, switchn_node) != SUCCESS) { + if (ProcessInputDtDynShape(node_ptr, is_dynamic_batch, switchn_node, dt_set) != SUCCESS) { GELOGE(INTERNAL_ERROR, "ProcessInputFP16 failed"); return FAILED; } // check if need to set format - bool set_format = false; - (void)ge::AttrUtils::GetBool(node_ptr->GetOpDesc(), "input_set_nc1hwc0", set_format); - if (set_format) { + string set_format; + bool ret = ge::AttrUtils::GetStr(node_ptr->GetOpDesc(), ATTR_ATC_USER_DEFINE_FORMAT, set_format); + if (ret && (!set_format.empty()) && TypeUtils::SerialStringToFormat(set_format) == FORMAT_NC1HWC0) { GELOGI("The format of node [%s] should be set NC1HWC0.", node_ptr->GetName().c_str()); if (ProcessInputNC1HWC0DynShape(node_ptr, is_dynamic_batch, switchn_node) != SUCCESS) { GELOGE(INTERNAL_ERROR, "ProcessInputNC1HWC0 
failed"); @@ -841,46 +824,43 @@ Status ProcessNetoutputNodeFp16Nc1hwc0DynShape(GeTensorDesc &src_desc, GeTensorD return SUCCESS; } -bool NeedUpdateOutputByOutputTypeParm(std::string &output_type, NodePtr &src_node, uint32_t src_index, - ge::DataType &dt) { - if (CheckIfSetOutputType(output_type, dt)) { - GELOGI("All output node should be set datatype."); - return true; +bool NeedUpdateDtByOutputTypeParm(OpDescPtr &netout_desc, uint32_t &index, ge::DataType &dt) { + GE_CHECK_NOTNULL(netout_desc); + vector output_dt_str; + if (ge::AttrUtils::GetListStr(netout_desc, ATTR_ATC_USER_DEFINE_DATATYPE, output_dt_str)) { + for (auto dt_str : output_dt_str) { + vector dt_str_split = StringUtils::Split(dt_str, ':'); + if (dt_str_split.size() == kUserDefinedElementCount) { + if (dt_str_split[0] == to_string(index)) { + dt = TypeUtils::SerialStringToDataType(dt_str_split[1]); + GELOGI("Find netoutput node output %u datatype should be set %s .", index, + TypeUtils::DataTypeToSerialString(dt).c_str()); + return true; + } + } + } } - bool is_dynamic = CheckOpType(src_node, MERGE); - auto op_desc = src_node->GetOpDesc(); - if (is_dynamic) { - const InDataAnchorPtr &merge_input_anchor = src_node->GetInDataAnchor(0); - GE_RT_FALSE_CHECK_NOTNULL(merge_input_anchor); - const OutDataAnchorPtr &src_out_anchor = merge_input_anchor->GetPeerOutAnchor(); - GE_RT_FALSE_CHECK_NOTNULL(src_out_anchor); - src_index = static_cast(src_out_anchor->GetIdx()); - auto src_merge_node = src_out_anchor->GetOwnerNode(); - GE_RT_FALSE_CHECK_NOTNULL(src_merge_node); - op_desc = src_merge_node->GetOpDesc(); - GE_RT_FALSE_CHECK_NOTNULL(op_desc); - } - vector output_data_type_vec; - vector index_vec; - if ((ge::AttrUtils::GetListDataType(op_desc, "_output_dt_list", output_data_type_vec)) && - (ge::AttrUtils::GetListInt(op_desc, "_output_dt_index", index_vec))) { - if (output_data_type_vec.size() != index_vec.size()) { - GELOGW("output_dt_list size is not match output_dt_index size"); - return false; - } - for 
(uint32_t i = 0; i < index_vec.size(); ++i) { - if (index_vec[i] == src_index) { - dt = output_data_type_vec[i]; - GELOGI("Find node %s output %u datatype should set %s .", op_desc->GetName().c_str(), i, - TypeUtils::DataTypeToSerialString(dt).c_str()); - return true; + return false; +} + +bool NeedUpdateFormatByOutputTypeParm(OpDescPtr &netout_desc, uint32_t &index) { + GE_CHECK_NOTNULL(netout_desc); + vector output_format_str; + if (ge::AttrUtils::GetListStr(netout_desc, ATTR_ATC_USER_DEFINE_FORMAT, output_format_str)) { + for (auto format_str : output_format_str) { + vector format_str_split = StringUtils::Split(format_str, ':'); + if (format_str_split.size() == kUserDefinedElementCount) { + if (format_str_split[0] == to_string(index)) { + GELOGI("Find netoutput node output %u format should be set NC1HWC0.", index); + return true; + } } } } return false; } -Status ProcessNetoutputNodeDynShape(NodePtr &node, std::string &output_type) { +Status ProcessNetoutputNodeDynShape(NodePtr &node) { auto op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); ge::DataType output_data_type = ge::DT_FLOAT; @@ -889,7 +869,6 @@ Status ProcessNetoutputNodeDynShape(NodePtr &node, std::string &output_type) { auto index = static_cast(in_anchor->GetIdx()); auto peer_out = in_anchor->GetPeerOutAnchor(); GE_CHECK_NOTNULL(peer_out); - auto src_index = static_cast(peer_out->GetIdx()); auto src_node = peer_out->GetOwnerNode(); GE_CHECK_NOTNULL(src_node); bool is_dynamic = CheckOpType(src_node, MERGE); @@ -899,11 +878,11 @@ Status ProcessNetoutputNodeDynShape(NodePtr &node, std::string &output_type) { auto net_output_input_desc = op_desc->MutableInputDesc(index); GE_CHECK_NOTNULL(net_output_input_desc); - ge::GeShape src_shape = src_op_desc->GetOutputDesc(src_index).GetShape(); - ge::Format src_format = src_op_desc->GetOutputDesc(src_index).GetFormat(); - ge::DataType src_dtype = src_op_desc->GetOutputDesc(src_index).GetDataType(); + ge::GeShape old_shape = 
net_output_input_desc->GetShape(); + ge::Format old_format = net_output_input_desc->GetFormat(); + ge::DataType old_dtype = net_output_input_desc->GetDataType(); // Update datatype - if (NeedUpdateOutputByOutputTypeParm(output_type, src_node, src_index, output_data_type)) { + if (NeedUpdateDtByOutputTypeParm(op_desc, index, output_data_type)) { GELOGI("Enter into process output_type schedule"); net_output_input_desc->SetDataType(output_data_type); if (is_dynamic) { @@ -916,32 +895,15 @@ Status ProcessNetoutputNodeDynShape(NodePtr &node, std::string &output_type) { merge_input->SetDataType(output_data_type); } } - continue; } - // output_node is not set,check if is_output_adjust_hw_layout is set - bool set_fp16_nc1hwc0 = false; - if (!is_dynamic) { - (void)AttrUtils::GetBool(src_op_desc, "output_set_fp16_nc1hwc0", set_fp16_nc1hwc0); - } else { - // need check dynamic scene, graph structure: node->merge->netoutput - const InDataAnchorPtr &merge_input_anchor = src_node->GetInDataAnchor(0); - GE_CHECK_NOTNULL(merge_input_anchor); - const OutDataAnchorPtr &src_out_anchor = merge_input_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(src_out_anchor); - auto src_merge_node = src_out_anchor->GetOwnerNode(); - GE_CHECK_NOTNULL(src_merge_node); - auto src_merge_node_opdesc = src_merge_node->GetOpDesc(); - (void)AttrUtils::GetBool(src_merge_node_opdesc, "output_set_fp16_nc1hwc0", set_fp16_nc1hwc0); - } - - if (set_fp16_nc1hwc0) { - GELOGI("Node [%s] should be set FP16 and NC1HWC0", src_op_desc->GetName().c_str()); - if ((src_format != FORMAT_NCHW) && (src_format != FORMAT_NHWC) && (src_format != FORMAT_NC1HWC0)) { + // check if is_output_adjust_hw_layout is set + if (NeedUpdateFormatByOutputTypeParm(op_desc, index)) { + if ((old_format != FORMAT_NCHW) && (old_format != FORMAT_NHWC) && (old_format != FORMAT_NC1HWC0)) { GELOGE(INTERNAL_ERROR, "Format is not one of NCHW, NHWC, NC1HWC0."); return FAILED; } - GeTensorDesc src_desc(src_shape, src_format, src_dtype); - if 
(ProcessNetoutputNodeFp16Nc1hwc0DynShape(src_desc, net_output_input_desc, src_node) != SUCCESS) { + GeTensorDesc old_desc(old_shape, old_format, old_dtype); + if (ProcessNetoutputNodeFp16Nc1hwc0DynShape(old_desc, net_output_input_desc, src_node) != SUCCESS) { GELOGE(INTERNAL_ERROR, "Process netoutput fp16 nc1hwc0."); return FAILED; } @@ -995,46 +957,6 @@ Status GraphPrepare::UpdateVariableFormats(ComputeGraphPtr &graph) { return SUCCESS; } -Status GraphPrepare::UpdateVariableFormatsDynShape(ComputeGraphPtr &graph) { - GE_CHECK_NOTNULL(graph); - auto var_names_to_refs = CollectVarNamesToRefs(graph); - for (auto &node : graph->GetAllNodes()) { - if (node == nullptr) { - continue; - } - if (node->GetType() != VARIABLE) { - continue; - } - auto trans_road = VarManager::Instance(graph->GetSessionID())->GetTransRoad(node->GetName()); - if (trans_road == nullptr) { - GELOGD("The variable %s does not have any trans road", node->GetName().c_str()); - continue; - } - - GELOGI("Recover the trans road for var %s reversely", node->GetName().c_str()); - - if (!(trans_road->empty())) { - auto ret = UpdateVarFormats(node, trans_road->rbegin()->output); - if (ret != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Failed to update var formats for var %s", node->GetName().c_str()); - return INTERNAL_ERROR; - } - } - - auto iter = var_names_to_refs.find(node->GetName()); - if (iter != var_names_to_refs.end()) { - for (auto &var : iter->second) { - if (!(trans_road->empty()) && (UpdateVarFormats(var, trans_road->rbegin()->input) != SUCCESS)) { - GELOGE(INTERNAL_ERROR, "Failed to update var formats for ref var %s", var->GetName().c_str()); - return INTERNAL_ERROR; - } - } - } - } - - return SUCCESS; -} - void GraphPrepare::SetOptions(const ge::GraphManagerOptions &options) { options_ = options; } Status GraphPrepare::Init(const ge::Graph &graph, uint64_t session_id) { @@ -1404,71 +1326,6 @@ Status GraphPrepare::UpdateDataNetOutputByStorageFormat() { return SUCCESS; } -void 
GraphPrepare::ProcessCCEFormat() { - static const char *const parser_priority = std::getenv("PARSER_PRIORITY"); - static const bool keep_cce = parser_priority != nullptr && string(parser_priority) == "cce"; - if (keep_cce) { - GELOGI("keep cce priority"); - for (const ge::NodePtr &n : compute_graph_->GetDirectNode()) { - auto node_op_desc = n->GetOpDesc(); - GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); - if (node_op_desc->GetType() == MULTIPLY || node_op_desc->GetType() == ASSIGN) { - auto input_size = static_cast(node_op_desc->GetInputsSize()); - for (uint32_t i = 0; i < input_size; ++i) { - ge::GeTensorDesc org_tensor_input = node_op_desc->GetInputDesc(i); - GELOGD("keep cce name:%s, type:%s", node_op_desc->GetName().c_str(), node_op_desc->GetType().c_str()); - if (org_tensor_input.GetFormat() == FORMAT_ND) { - org_tensor_input.SetFormat(FORMAT_NCHW); - org_tensor_input.SetOriginFormat(FORMAT_NCHW); - // [No need to check value] - (void)node_op_desc->UpdateInputDesc(i, org_tensor_input); - } - } - auto output_size = static_cast(node_op_desc->GetOutputsSize()); - for (uint32_t i = 0; i < output_size; ++i) { - ge::GeTensorDesc org_tensor_output = node_op_desc->GetOutputDesc(i); - GELOGD("keep cce name:%s, type:%s", node_op_desc->GetName().c_str(), node_op_desc->GetType().c_str()); - if (org_tensor_output.GetFormat() == FORMAT_ND) { - org_tensor_output.SetFormat(FORMAT_NCHW); - org_tensor_output.SetOriginFormat(FORMAT_NCHW); - // [No need to check value] - (void)node_op_desc->UpdateOutputDesc(i, org_tensor_output); - } - } - } - } - } -} - -Status GraphPrepare::OptimizeBeforeInfershape() { - PassManager graph_passes_before_infershape; - // Graph pass - try { - if (options_.train_graph_flag) { - (void)graph_passes_before_infershape.AddPass("OptimizeBeforeInfershape::SavePass", new SavePass); - } - (void)graph_passes_before_infershape.AddPass("OptimizeBeforeInfershape::NetOutputPass", new NetOutputPass); - } catch (std::bad_alloc &e) { - GELOGE(INTERNAL_ERROR, 
"Add pass failed, bad memory allocation occurs."); - return INTERNAL_ERROR; - } - GE_TIMESTAMP_START(graph_passes_before_infershape); - Status ret = graph_passes_before_infershape.Run(compute_graph_); - GE_TIMESTAMP_END(graph_passes_before_infershape, "GraphPrepare::BeforeInfershape"); - bool status = (ret != SUCCESS && ret != NOT_CHANGED); - if (status) { - GELOGE(ret, "Run graph_passes_before_infershape failed, ret:%u.", ret); - return ret; - } - - graphStatus ret_topo = compute_graph_->TopologicalSorting(); - if (ret_topo != GRAPH_SUCCESS) { - GELOGE(ret_topo, "Graph topological sort failed, ret:%u.", ret_topo); - return ret_topo; - } - return SUCCESS; -} - Status GraphPrepare::SaveOriginalGraphToOmModel() { if (options_.save_original_model == "true") { ModelHelper model_helper; @@ -1482,89 +1339,6 @@ Status GraphPrepare::SaveOriginalGraphToOmModel() { return SUCCESS; } -Status GraphPrepare::Preprocess(const std::vector &user_input) { - // rtContext_t... - Status ret = SetRtContext(rtContext_t(), RT_CTX_GEN_MODE); - if (ret != SUCCESS) { - GELOGE(ret, "Set rt context failed."); - return ret; - } - - ret = CheckAndUpdateInput(user_input); - if (ret != SUCCESS) { - GELOGE(ret, "Check user input failed."); - return ret; - } - GE_DUMP(compute_graph_, "after_update_input"); - - GEPass ge_passes(compute_graph_); - NamesToPass names_to_passes; - ForPass for_pass; - names_to_passes.emplace_back("ForPass", &for_pass); - GE_TIMESTAMP_START(names_to_passes); - ret = ge_passes.Run(names_to_passes); - GE_TIMESTAMP_END(names_to_passes, "GraphPrepare::ForPass"); - if (ret != SUCCESS) { - GELOGE(ret, "Run ForPass optimize for preprocess failed, ret:%u.", ret); - return ret; - } - GE_DUMP(compute_graph_, "after_for_pass"); - - GE_TIMESTAMP_START(netoutput_process); - ret = ProcessNetOutput(); - GE_TIMESTAMP_END(netoutput_process, "GraphPrepare::NetOutputProcess") - if (ret != SUCCESS) { - return ret; - } - GE_TIMESTAMP_START(multibatch_process); - ret = 
multibatch::ProcessMultiBatch(compute_graph_); - GE_TIMESTAMP_END(multibatch_process, "GraphPrepare::MultiBatchProcess") - if (ret != SUCCESS) { - GELOGE(ret, "Failed to do multi-batch processing"); - return ret; - } - GE_DUMP(compute_graph_, "after_multibatch_process"); - - ret = TryDoAipp(); - if (ret != SUCCESS) { - return ret; - } - - GE_TIMESTAMP_START(FormatAndShapeProcess); - ret = FormatAndShapeProcess(); - GE_TIMESTAMP_END(FormatAndShapeProcess, "GraphPrepare::FormatAndShapeProcess"); - if (ret != SUCCESS) { - GELOGE(ret, "FormatAndShape process failed"); - return ret; - } - GE_DUMP(compute_graph_, "after_inferformat_before_preprocess"); - - ProcessCCEFormat(); - - SaveOriginalGraphToOmModel(); - - GE_TIMESTAMP_START(OptimizeForPreprocess); - ret = OptimizeForPreprocess(); - GE_TIMESTAMP_END(OptimizeForPreprocess, "GraphPrepare::OptimizeForPreprocess"); - if (ret != SUCCESS) { - GELOGE(ret, "Optimize for preprocess failed."); - return ret; - } - GELOGI("Optimize for preprocess success."); - - GE_TIMESTAMP_START(UpdateVariableFormats); - ret = UpdateVariableFormats(compute_graph_); - GE_TIMESTAMP_END(UpdateVariableFormats, "GraphPrepare::UpdateVariableFormats"); - if (ret != SUCCESS) { - GELOGE(ret, "Failed to update variables formats"); - return ret; - } - GELOGI("Update variable formats success."); - - GE_DUMP(compute_graph_, "Optimize_after_preprocess"); - return SUCCESS; -} - #define PP_RUN_AND_DUMP(name, func, ...) 
\ do { \ GE_RUN(Prepare, func, __VA_ARGS__); \ @@ -1683,81 +1457,6 @@ Status GraphPrepare::GenerateInfershapeGraph(ConstGraphPtr graph) { return SUCCESS; } -Status GraphPrepare::Prepare(ConstGraphPtr graph, const std::vector &user_input, - ge::ComputeGraphPtr &compute_graph, VarAccelerateCtrl &var_acc_ctrl, uint64_t session_id) { - domi::GetContext().type = static_cast(options_.framework_type); - - if (graph == nullptr) { - GELOGE(GE_GRAPH_NULL_INPUT, "Input Graph is NULL"); - return GE_GRAPH_NULL_INPUT; - } - const Graph &const_graph = *graph; - Status ret = Init(const_graph, session_id); - if (ret != SUCCESS) { - GELOGE(ret, "Init graph_prepare fail, ret:%u", ret); - return ret; - } - - GraphOptimize graph_optimize; - if (!options_.train_graph_flag && !domi::GetContext().train_flag) { - GE_DUMP(compute_graph_, "BeforeOriginalGraphForQuantize"); - GE_TIMESTAMP_START(OptimizeOriginalGraphForQuantize); - ret = graph_optimize.OptimizeOriginalGraphForQuantize(compute_graph_); - GE_TIMESTAMP_END(OptimizeOriginalGraphForQuantize, "GraphPrepare::OptimizeOriginalGraphForQuantize"); - if (ret != SUCCESS) { - GELOGE(ret, "originalGraph optimize for Quantize Failed"); - return ret; - } - } - GE_DUMP(compute_graph_, "BeforePreprocess"); - - GE_TIMESTAMP_START(Preprocess); - ret = Preprocess(user_input); - GE_TIMESTAMP_END(Preprocess, "GraphPrepare::Preprocess"); - if (ret != SUCCESS) { - GELOGE(ret, "Run graph_prepare fail, ret:%u", ret); - return ret; - } - // OriginalGraph optimize - ret = graph_optimize.SetOptions(options_); - GE_CHK_STATUS_RET(ret, "Graph optimize initial fail"); - if (options_.local_fmk_op_flag) { - graph_optimize.TranFrameOp(compute_graph_); - } - GE_DUMP(compute_graph_, "Prepare"); - - GE_TIMESTAMP_START(OptimizeOriginalGraph); - const char *buffer_optimize_on = std::getenv("BUFFER_OPTIMIZE_ON"); - if (buffer_optimize_on != nullptr) { - ret = graph_optimize.NewOptimizeOriginalGraph(compute_graph_); - } else { - ret = 
graph_optimize.OptimizeOriginalGraph(compute_graph_); - } - GE_TIMESTAMP_END(OptimizeOriginalGraph, "GraphPrepare::OptimizeOriginalGraph"); - GE_DUMP(compute_graph_, "PreProcessOptimizeOriginalGraphAfter"); - if (ret != SUCCESS) { - GELOGE(ret, "originalGraph optimize Failed"); - return ret; - } - - GE_RETURN_IF_ERROR(RecordAIPPInfo(compute_graph_)); - - GE_TIMESTAMP_START(OptimizeBeforeSubGraph); - - if (buffer_optimize_on != nullptr) { - ret = NewOptimizeGraphBeforeSubGraph(var_acc_ctrl); - } else { - ret = OptimizeGraphBeforeSubGraph(); - } - GE_TIMESTAMP_END(OptimizeBeforeSubGraph, "GraphPrepare::OptimizeBeforeSubGraph"); - if (ret != SUCCESS) { - GELOGE(ret, "originalGraph optimize Failed"); - return ret; - } - compute_graph = compute_graph_; - return SUCCESS; -} - Status GraphPrepare::CheckConstOp() { for (auto &node_ptr : compute_graph_->GetAllNodes()) { GE_CHECK_NOTNULL(node_ptr); @@ -1959,6 +1658,7 @@ Status GraphPrepare::PrepareOptimize() { VarIsInitializedOpPass var_is_initialized_pass; ParallelConcatStartOpPass parallel_concat_start_op_pass; IdentityPass identity_pass(false); + AssignPass assign_pass; SnapshotPass snapshot_pass; if (!options_.train_graph_flag) { names_to_passes.emplace_back("DropOutPass", &dropout_pass); @@ -1973,6 +1673,9 @@ Status GraphPrepare::PrepareOptimize() { names_to_passes.emplace_back("VarIsInitializedOpPass", &var_is_initialized_pass); names_to_passes.emplace_back("ParallelConcatStartOpPass", ¶llel_concat_start_op_pass); names_to_passes.emplace_back("IdentityPass", &identity_pass); + if (GetContext().GetHostExecFlag()) { + names_to_passes.emplace_back("AssignPass", &assign_pass); + } GE_TIMESTAMP_START(names_to_passes); ret = ge_passes.Run(names_to_passes); GE_TIMESTAMP_END(names_to_passes, "GraphPrepare::NamesToPasses"); @@ -2032,117 +1735,6 @@ void GraphPrepare::TypeConversionOfConstant() { } } -Status GraphPrepare::OptimizeForPreprocess() { - GELOGI("Start optimize for preprocess."); - PassManager original_graph_passes; - 
// Graph pass - try { - (void)original_graph_passes.AddPass("OptimizeForPreprocess::ConstantFuseSamePass", new ConstantFuseSamePass); - (void)original_graph_passes.AddPass("OptimizeForPreprocess::VariablePrepareOpPass", new VariablePrepareOpPass); - (void)original_graph_passes.AddPass("OptimizeForPreprocess::IteratorOpPass", new IteratorOpPass); - (void)original_graph_passes.AddPass("OptimizeForPreprocess::ShapeOperateOpRemovePass", - new ShapeOperateOpRemovePass); - (void)original_graph_passes.AddPass("OptimizeForPreprocess::ReplaceTransShapePass", new ReplaceTransShapePass); - } catch (std::bad_alloc &e) { - GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); - return INTERNAL_ERROR; - } - - GE_TIMESTAMP_START(original_graph_passes); - Status ret = original_graph_passes.Run(compute_graph_); - GE_TIMESTAMP_END(original_graph_passes, "GraphPrepare::OriginalGraphPasses"); - if (ret != SUCCESS && ret != NOT_CHANGED) { - GELOGE(ret, "Run graph passes optimize for preprocess failed, ret:%u.", ret); - return ret; - } - // New pass - GEPass ge_passes(compute_graph_); - NamesToPass names_to_passes; - EnterPass enter_pass; - names_to_passes.emplace_back("EnterPass", &enter_pass); - CondPass cond_pass; - names_to_passes.emplace_back("CondPass", &cond_pass); - AddNPass addn_pass; - names_to_passes.emplace_back("AddNPass", &addn_pass); - PrintOpPass print_pass; - if (options_.enable_print_op_pass) { - names_to_passes.emplace_back("PrintOpPass", &print_pass); - } - NoUseReshapeRemovePass no_use_reshape_remove_pass; - names_to_passes.emplace_back("NoUseReshapeRemovePass", &no_use_reshape_remove_pass); - - // for infer - DropOutPass dropout_pass; - AssertPass assert_pass; - if (!options_.train_graph_flag) { - names_to_passes.emplace_back("DropOutPass", &dropout_pass); - names_to_passes.emplace_back("AssertPass", &assert_pass); - } - UnusedConstPass unused_const_pass; - names_to_passes.emplace_back("UnusedConstPass", &unused_const_pass); - StopGradientPass 
stop_gradient_pass; - names_to_passes.emplace_back("StopGradientPass", &stop_gradient_pass); - PreventGradientPass prevent_gradient_pass; - names_to_passes.emplace_back("PreventGradientPass", &prevent_gradient_pass); - PlaceholderWithDefaultPass placeholder_with_default_pass; - names_to_passes.emplace_back("PlaceholderWithDefaultPass", &placeholder_with_default_pass); - SnapshotPass snapshot_pass; - names_to_passes.emplace_back("SnapshotPass", &snapshot_pass); - GuaranteeConstPass guarantee_const_pass; - names_to_passes.emplace_back("GuaranteeConstPass", &guarantee_const_pass); - VarIsInitializedOpPass var_is_initialized_pass; - names_to_passes.emplace_back("VarIsInitializedOpPass", &var_is_initialized_pass); - ParallelConcatStartOpPass parallel_concat_start_op_pass; - names_to_passes.emplace_back("ParallelConcatStartOpPass", ¶llel_concat_start_op_pass); - IdentityPass identity_pass(false); - names_to_passes.emplace_back("IdentityPass", &identity_pass); - SwitchDeadBranchElimination switch_dead_branch_elimination; - names_to_passes.emplace_back("SwitchDeadBranchElimination", &switch_dead_branch_elimination); - SwitchLogicRemovePass switch_logic_remove_pass; - names_to_passes.emplace_back("SwitchLogicRemovePass", &switch_logic_remove_pass); - MergePass merge_pass; - names_to_passes.emplace_back("MergePass", &merge_pass); - GE_TIMESTAMP_START(names_to_passes); - ret = ge_passes.Run(names_to_passes); - GE_TIMESTAMP_END(names_to_passes, "GraphPrepare::NamesToPasses"); - if (ret != SUCCESS) { - GELOGE(ret, "Run ge_passes optimize for preprocess failed, ret:%u.", ret); - return ret; - } - - PassManager graph_pass; - try { - (void)graph_pass.AddPass("OptimizeForPreprocess::PrunePass", new PrunePass); - (void)graph_pass.AddPass("OptimizeForPreprocess::NextIterationPass", new NextIterationPass); - (void)graph_pass.AddPass("OptimizeForPreprocess::ControlTriggerPass", new ControlTriggerPass); - (void)graph_pass.AddPass("OptimizeForPreprocess::MergeToStreamMergePass", new 
MergeToStreamMergePass); - (void)graph_pass.AddPass("OptimizeForPreprocess::SwitchToStreamSwitchPass", new SwitchToStreamSwitchPass); - (void)graph_pass.AddPass("OptimizeForPreprocess::AttachStreamLabelPass", new AttachStreamLabelPass); - (void)graph_pass.AddPass("OptimizeForPreprocess::HcclMemcpyPass", new HcclMemcpyPass); - GE_IF_BOOL_EXEC(options_.train_graph_flag, - (void)graph_pass.AddPass("OptimizeForPreprocess::FlowCtrlPass", new FlowCtrlPass);); - } catch (std::bad_alloc &e) { - GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); - return INTERNAL_ERROR; - } - - ret = graph_pass.Run(compute_graph_); - if (ret != SUCCESS && ret != NOT_CHANGED) { - GELOGE(ret, "Run graph passes optimize for preprocess failed, ret:%u.", ret); - return ret; - } - - ret = compute_graph_->TopologicalSorting(); - if (ret != SUCCESS) { - GELOGE(ret, "Graph topological sort failed, ret:%u.", ret); - return ret; - } - - GELOGI("End optimize for preprocess."); - - return SUCCESS; -} - Status GraphPrepare::GraphEquivalentTransformation() { NamesToPass names_to_pass; ForPass for_pass; @@ -2186,117 +1778,6 @@ Status GraphPrepare::ProcessNetOutput() { return SUCCESS; } -Status GraphPrepare::NewOptimizeGraphBeforeSubGraph(VarAccelerateCtrl &var_acc_ctrl) { - GELOGD("NewOptimizeGraphBeforeSubGraph in"); - PassManager passes; - (void)passes.AddPass("NewOptimizeGraphBeforeSubGraph::CommonSubexpressionEliminationPass", - new (std::nothrow) CommonSubexpressionEliminationPass); - auto ret = passes.Run(compute_graph_); - if (ret != SUCCESS) { - GELOGE(ret, "Failed to optimize for graph"); - return ret; - } - - GEPass ge_passes_for_shape(compute_graph_); - NamesToPass names_to_passes_for_shape; - CastRemovePass cast_remove_pass; - names_to_passes_for_shape.emplace_back("CastRemovePass", &cast_remove_pass); - TransposeTransDataPass transpose_transdata_pass; - names_to_passes_for_shape.emplace_back("TransposeTransDataPass", &transpose_transdata_pass); - 
GE_TIMESTAMP_START(ge_passes_for_shape); - ret = ge_passes_for_shape.Run(names_to_passes_for_shape); - GE_TIMESTAMP_END(ge_passes_for_shape, "GraphManager::GePassesForShape"); - if (ret != SUCCESS) { - GELOGE(ret, "Run ge_passes_for_shape optimize for OptimizeGraphBeforeSubGraph failed, ret:%d.", ret); - return ret; - } - - string options = "default"; - if (GetContext().GetOption("ge.exec.variable_acc", options) != SUCCESS) { - GELOGI("get ge.exec.variable_acc failed. set default value."); - } - PassManager pass_manager; - GE_CHK_STATUS_RET(pass_manager.AddPass("NewOptimizeGraphBeforeSubGraph::PermutePass", new (std::nothrow) PermutePass)) - GE_CHK_STATUS_RET(pass_manager.AddPass("NewOptimizeGraphBeforeSubGraph::VariablePrepareOpPass", - new (std::nothrow) VariablePrepareOpPass)) - GE_IF_BOOL_EXEC(options == "default" || options == "1", GELOGI("turn on variable accelerator"); - GE_CHK_STATUS_RET(pass_manager.AddPass("NewOptimizeGraphBeforeSubGraph::VariableOpPass", - new (std::nothrow) VariableOpPass(&var_acc_ctrl)))) - GE_CHK_STATUS_RET(pass_manager.AddPass("NewOptimizeGraphBeforeSubGraph::TransOpWithoutReshapeFusionPass", - new (std::nothrow) TransOpWithoutReshapeFusionPass)) - GE_CHK_STATUS_RET(pass_manager.AddPass("NewOptimizeGraphBeforeSubGraph::TransOpDepthFusionPass", - new (std::nothrow) TransOpDepthFusionPass)) - GE_CHK_STATUS_RET(pass_manager.AddPass("NewOptimizeGraphBeforeSubGraph::TransOpBreadthFusionPass", - new (std::nothrow) TransOpBreadthFusionPass)) - GE_CHK_STATUS_RET(pass_manager.AddPass("NewOptimizeGraphBeforeSubGraph::VariableRefDeleteOpPass", - new (std::nothrow) VariableRefDeleteOpPass)) - GE_CHK_STATUS_RET(pass_manager.AddPass("NewOptimizeGraphBeforeSubGraph::SameTransdataBreadthFusionPass", - new (std::nothrow) SameTransdataBreadthFusionPass)) - GE_CHK_STATUS_RET(pass_manager.AddPass("NewOptimizeGraphBeforeSubGraph::LinkGenMaskNodesPass", - new (std::nothrow) LinkGenMaskNodesPass(options_.stream_max_parallel_num))) - - 
GE_TIMESTAMP_START(pass_manager); - ret = pass_manager.Run(compute_graph_); - GE_TIMESTAMP_END(pass_manager, "GraphManager::BeforeSubGraph"); - if (ret != SUCCESS && ret != NOT_CHANGED) { - GELOGE(ret, "Run passes after merge sub graph failed, ret:%d.", ret); - return ret; - } - - // add variable attr for hccl broadcast,need to be removed after variable pass online - for (const ge::NodePtr &node : compute_graph_->GetDirectNode()) { - if (node->GetOpDesc()->GetType() != VARIABLE) { - continue; - } - if (IsBroadCastOpData(node)) { - AdjustBroadCastOpData(node); - } - if (IsAssignOpData(node)) { - AdjustAssignOpData(node); - } - } - - NamesToPass names_to_passes; - TransOpNearbyAllreduceFusionPass trans_op_nearby_allreduce_fusion_pass; - names_to_passes.emplace_back("TransOpNearbyAllreduceFusionPass", &trans_op_nearby_allreduce_fusion_pass); - ReshapeRemovePass reshape_remove_pass; - names_to_passes.emplace_back("ReshapeRemovePass", &reshape_remove_pass); - ConstantFoldingPass constant_folding_pass; - names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass); - DimensionAdjustPass dimension_adjust_pass; - names_to_passes.emplace_back("DimensionAdjustPass", &dimension_adjust_pass); - GEPass ge_passes(compute_graph_); - ret = ge_passes.Run(names_to_passes); - if (ret != SUCCESS) { - GELOGE(ret, "Failed to optimize for graph"); - return ret; - } - return SUCCESS; -} - -Status GraphPrepare::OptimizeGraphBeforeSubGraph() { - PassManager passes; - (void)passes.AddPass("OptimizeGraphBeforeSubGraph::VariablePrepareOpPass", new (std::nothrow) VariablePrepareOpPass); - (void)passes.AddPass("OptimizeGraphBeforeSubGraph::CommonSubexpressionEliminationPass", - new (std::nothrow) CommonSubexpressionEliminationPass); - auto ret = passes.Run(compute_graph_); - if (ret != SUCCESS) { - GELOGE(ret, "Failed to optimize for graph"); - return ret; - } - ConstantFoldingPass constant_folding_pass; - DimensionComputePass dimension_compute_pass; - NamesToPass names_to_passes; 
- names_to_passes.emplace_back("DimensionComputePass", &dimension_compute_pass); - names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass); - GEPass ge_passes(compute_graph_); - ret = ge_passes.Run(names_to_passes); - if (ret != SUCCESS) { - GELOGE(ret, "Failed to optimize for graph"); - return ret; - } - return SUCCESS; -} Status GraphPrepare::CheckAndUpdateInput(const std::vector &user_input) { compute_graph_->SetInputSize(user_input.size()); if (user_input.empty()) { @@ -2355,7 +1836,7 @@ Status GraphPrepare::UpdateInputOutputByOptions() { } if (node_ptr->GetType() == ge::NETOUTPUT) { - if (ProcessNetoutputNodeDynShape(node_ptr, options_.output_datatype) != SUCCESS) { + if (ProcessNetoutputNodeDynShape(node_ptr) != SUCCESS) { GELOGE(INTERNAL_ERROR, "Process netoutput node failed"); return FAILED; } @@ -2364,21 +1845,6 @@ Status GraphPrepare::UpdateInputOutputByOptions() { return SUCCESS; } -bool GraphPrepare::IsBroadCastOpData(const ge::NodePtr &var_node) { - for (auto &out_anchor : var_node->GetAllOutDataAnchors()) { - GE_RT_FALSE_CHECK_NOTNULL(out_anchor); - for (auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { - GE_RT_FALSE_CHECK_NOTNULL(in_anchor); - ge::NodePtr dst_node = in_anchor->GetOwnerNode(); - GE_RT_FALSE_CHECK_NOTNULL(dst_node); - if (dst_node->GetType() == HCOMBROADCAST || dst_node->GetType() == HVDCALLBACKBROADCAST) { - return true; - } - } - } - return false; -} - bool GraphPrepare::IsTansDataOpData(const ge::NodePtr &var_node) { for (auto &out_anchor : var_node->GetAllOutDataAnchors()) { GE_RT_FALSE_CHECK_NOTNULL(out_anchor); @@ -2393,64 +1859,4 @@ bool GraphPrepare::IsTansDataOpData(const ge::NodePtr &var_node) { } return false; } - -bool GraphPrepare::ConfirmUseOpAndIndexByAnchor(const ge::InDataAnchorPtr &in_anchor, - const map> &confirm_ops, ge::NodePtr &use_node) { - GE_RT_FALSE_CHECK_NOTNULL(in_anchor); - ge::NodePtr dst_node = in_anchor->GetOwnerNode(); - GE_RT_FALSE_CHECK_NOTNULL(dst_node); - ge::OpDescPtr 
dst_op_desc = dst_node->GetOpDesc(); - GE_RT_FALSE_CHECK_NOTNULL(dst_op_desc); - const string &dst_type = dst_op_desc->GetType(); - int input_index = in_anchor->GetIdx(); - - GELOGD("ConfirmUseOpAndIndex, var name %s, dst_type = %s, input index %d", dst_node->GetName().c_str(), - dst_type.c_str(), input_index); - - if (confirm_ops.count(dst_type) > 0) { - if (confirm_ops.at(dst_type).count(input_index) > 0) { - use_node = dst_node; - return true; - } - } - return false; -} - -bool GraphPrepare::ConfirmUseOpAndIndexByNode(const ge::NodePtr &var_node, - const map> &confirm_ops, ge::NodePtr &use_node) { - GE_RT_FALSE_CHECK_NOTNULL(var_node); - for (auto &out_anchor : var_node->GetAllOutDataAnchors()) { - GE_RT_FALSE_CHECK_NOTNULL(out_anchor); - for (auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { - GE_RT_FALSE_CHECK_NOTNULL(in_anchor); - if (ConfirmUseOpAndIndexByAnchor(in_anchor, confirm_ops, use_node)) { - return true; - } - } - } - return false; -} -void GraphPrepare::AdjustBroadCastOpData(const ge::NodePtr &var_node) { - if (!ge::AttrUtils::SetStr(var_node->GetOpDesc(), VAR_ATTR_VAR_IS_BROADCAST, "var_is_restore")) { - GELOGW("set var_is_restore failed"); - } -} - -bool GraphPrepare::IsAssignOpData(const ge::NodePtr &var_node) { - GELOGD("IsAssignOpData var_node %s", var_node->GetName().c_str()); - std::map> assign_ops = {{ASSIGN, {0}}}; - - ge::NodePtr assign_node = nullptr; - if (ConfirmUseOpAndIndexByNode(var_node, assign_ops, assign_node)) { - return true; - } - - return false; -} - -void GraphPrepare::AdjustAssignOpData(const ge::NodePtr &var_node) { - if (!ge::AttrUtils::SetStr(var_node->GetOpDesc(), VAR_ATTR_VAR_IS_RESTORE, "var_is_restore")) { - GELOGW("SetStr var_is_restore failed"); - } -} } // namespace ge diff --git a/src/ge/graph/preprocess/graph_preprocess.h b/src/ge/graph/preprocess/graph_preprocess.h index 343791bd..7c6e4013 100644 --- a/src/ge/graph/preprocess/graph_preprocess.h +++ b/src/ge/graph/preprocess/graph_preprocess.h @@ -45,8 
+45,6 @@ class GraphPrepare { virtual ~GraphPrepare(); GraphPrepare(const GraphPrepare &in) = delete; GraphPrepare &operator=(const GraphPrepare &in) = delete; - Status Prepare(ConstGraphPtr graph, const std::vector &user_input, ge::ComputeGraphPtr &compute_graph, - VarAccelerateCtrl &var_acc_ctrl, uint64_t session_id = 0); Status PrepareDynShape(ConstGraphPtr graph, const std::vector &user_input, ge::ComputeGraphPtr &compute_graph, uint64_t session_id = 0); Status RecordAIPPInfo(ge::ComputeGraphPtr &compute_graph); @@ -57,7 +55,6 @@ class GraphPrepare { private: Status Init(const ge::Graph &graph, uint64_t session_id = 0); - Status Preprocess(const std::vector &user_input); Status CheckGraph(); Status CheckRefInputNode(const NodePtr &node, const std::string &input_name, const std::set &ref_nodes); Status CheckRefOp(); @@ -69,37 +66,19 @@ class GraphPrepare { Status VerifyConstOp(const NodePtr &node); Status CheckUserInput(const std::vector &user_input); Status UpdateDataNetOutputByStorageFormat(); - Status OptimizeForPreprocess(); Status PrepareOptimize(); Status InferShapeForPreprocess(); Status TryDoAipp(); Status UpdateVariableFormats(ComputeGraphPtr &graph); - Status UpdateVariableFormatsDynShape(ComputeGraphPtr &graph); Status FormatAndShapeProcess(); Status ResourcePairProcess(const std::string &action); - void ProcessCCEFormat(); - Status OptimizeBeforeInfershape(); - Status OptimizeGraphBeforeSubGraph(); - Status NewOptimizeGraphBeforeSubGraph(VarAccelerateCtrl &var_acc_ctrl); Status SaveOriginalGraphToOmModel(); Status ProcessNetOutput(); Status ProcessBeforeInfershape(); Status UpdateInputOutputByOptions(); - bool IsBroadCastOpData(const ge::NodePtr &var_node); bool IsTansDataOpData(const ge::NodePtr &var_node); - void AdjustBroadCastOpData(const ge::NodePtr &var_node); - - bool IsAssignOpData(const ge::NodePtr &var_node); - - void AdjustAssignOpData(const ge::NodePtr &var_node); - - bool ConfirmUseOpAndIndexByAnchor(const ge::InDataAnchorPtr &in_anchor, 
const map> &confirm_ops, - ge::NodePtr &use_node); - - bool ConfirmUseOpAndIndexByNode(const ge::NodePtr &var_node, const map> &confirm_ops, - ge::NodePtr &use_node); Status GraphEquivalentTransformation(); void TypeConversionOfConstant(); diff --git a/src/ge/graph/preprocess/insert_op/ge_aipp_op.cc b/src/ge/graph/preprocess/insert_op/ge_aipp_op.cc index 55c7b427..c231ef15 100644 --- a/src/ge/graph/preprocess/insert_op/ge_aipp_op.cc +++ b/src/ge/graph/preprocess/insert_op/ge_aipp_op.cc @@ -100,6 +100,7 @@ const uint64_t kMinTransferShape = 3; const int kAippImageInputIndex = 0; const int kAippParamsInputIndex = 1; const int kAippDataOutputIndex = 0; +const int64_t kDynamicDim = -1; // the `format` must one NCHW or NHWC Status GetDataDimN(const ge::NodePtr &data_node, ge::Format format, int64_t &batch) { @@ -181,7 +182,7 @@ Status AippOp::InsertAippToGraph(ComputeGraphPtr &graph, std::string &aippConfig for (auto &out_in_anchors : target_edges) { auto iter = out_anchors_to_aipp.find(out_in_anchors.first); if (iter == out_anchors_to_aipp.end()) { - auto aipp = CreateAipp(graph, out_in_anchors.first, aippConfigPath, index); + auto aipp = CreateAipp(out_in_anchors.first, aippConfigPath, index); GE_CHECK_NOTNULL(aipp); out_anchors_to_aipp[out_in_anchors.first] = aipp; @@ -193,7 +194,7 @@ Status AippOp::InsertAippToGraph(ComputeGraphPtr &graph, std::string &aippConfig // add aipp data if needed if (GetAippMode() == domi::AippOpParams::dynamic) { - ret = CreateAippData(graph, aipp); + ret = CreateAippData(aipp); if (ret != SUCCESS) { GELOGE(INTERNAL_ERROR, "Failed to create aipp data for aipp %s data %s", aipp->GetName().c_str(), out_in_anchors.first->GetOwnerNode()->GetName().c_str()); @@ -215,10 +216,10 @@ Status AippOp::InsertAippToGraph(ComputeGraphPtr &graph, std::string &aippConfig return SUCCESS; } -NodePtr AippOp::CreateAipp(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, - const std::string &aippConfigPath, const uint32_t &index) { - std::string 
current_name = - out_anchor->GetOwnerNode()->GetName() + "_" + std::to_string(out_anchor->GetIdx()) + "_huawei_aipp"; +NodePtr AippOp::CreateAipp(const OutDataAnchorPtr &out_anchor, const std::string &aippConfigPath, + const uint32_t &index) { + const auto &node = out_anchor->GetOwnerNode(); + std::string current_name = node->GetName() + "_" + std::to_string(out_anchor->GetIdx()) + "_huawei_aipp"; auto aipp_opdesc_ptr = MakeShared(current_name, AIPP); if (aipp_opdesc_ptr == nullptr) { GELOGE(OUT_OF_MEMORY, "Failed to alloc aipp desc, name %s", current_name.c_str()); @@ -226,42 +227,7 @@ NodePtr AippOp::CreateAipp(const ComputeGraphPtr &graph, const OutDataAnchorPtr } // Update attributes - GeAttrValue::NAMED_ATTRS aipp_attr; - ConvertParamToAttr(aipp_attr); - if (!AttrUtils::SetNamedAttrs(aipp_opdesc_ptr, ATTR_NAME_AIPP, aipp_attr)) { - GELOGE(INTERNAL_ERROR, "Set name attrs for aipp node failed"); - return nullptr; - } - if (!AttrUtils::SetStr(aipp_opdesc_ptr, kAippConfigPath, aippConfigPath)) { - GELOGE(INTERNAL_ERROR, "Set config file path attr for aipp node failed"); - return nullptr; - } - if (!AttrUtils::SetListStr(aipp_opdesc_ptr, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, std::vector())) { - GELOGE(INTERNAL_ERROR, "Set ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES attr for aipp node failed"); - return nullptr; - } - if (!AttrUtils::SetInt(aipp_opdesc_ptr, kCurrentAippIndex, index)) { - GELOGE(INTERNAL_ERROR, "Set kCurrentAippIndex attr for aipp node failed"); - return nullptr; - } - - // add input/output desc - GeTensorDesc tensor; - auto ret = aipp_opdesc_ptr->AddInputDesc("images", tensor); - if (ret != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Failed to add input images for aipp node"); - return nullptr; - } - if (GetAippMode() == domi::AippOpParams::dynamic) { - ret = aipp_opdesc_ptr->AddOptionalInputDesc("params", tensor); - if (ret != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Failed to add input params for aipp node"); - return nullptr; - } - } - ret = 
aipp_opdesc_ptr->AddOutputDesc("features", tensor); - if (ret != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Failed to add output features for aipp node"); + if (AddAippAttrbutes(aipp_opdesc_ptr, aippConfigPath, index) != SUCCESS) { return nullptr; } @@ -287,7 +253,35 @@ NodePtr AippOp::CreateAipp(const ComputeGraphPtr &graph, const OutDataAnchorPtr return nullptr; } - return graph->AddNode(aipp_opdesc_ptr); + return node->GetOwnerComputeGraph()->AddNode(aipp_opdesc_ptr); +} + +Status AippOp::AddAippAttrbutes(const OpDescPtr &op_desc, const std::string &aipp_cfg_path, const uint32_t &index) { + GeAttrValue::NAMED_ATTRS aipp_attr; + ConvertParamToAttr(aipp_attr); + GE_CHK_BOOL_RET_STATUS(AttrUtils::SetNamedAttrs(op_desc, ATTR_NAME_AIPP, aipp_attr), INTERNAL_ERROR, + "Set name attrs for aipp node failed"); + + GE_CHK_BOOL_RET_STATUS(AttrUtils::SetStr(op_desc, kAippConfigPath, aipp_cfg_path), INTERNAL_ERROR, + "Set config file path attr for aipp node failed"); + + std::vector empty_names; + GE_CHK_BOOL_RET_STATUS(AttrUtils::SetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, empty_names), + INTERNAL_ERROR, "Set ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES attr for aipp node failed"); + + GE_CHK_BOOL_RET_STATUS(AttrUtils::SetInt(op_desc, kCurrentAippIndex, index), INTERNAL_ERROR, + "Set kCurrentAippIndex attr for aipp node failed"); + + // add input/output desc + GeTensorDesc tensor; + GE_CHK_GRAPH_STATUS_RET(op_desc->AddInputDesc("images", tensor), "Failed to add input images for aipp node"); + + if (GetAippMode() == domi::AippOpParams::dynamic) { + GE_CHK_GRAPH_STATUS_RET(op_desc->AddOptionalInputDesc("params", tensor), "Failed to add params for aipp node"); + } + GE_CHK_GRAPH_STATUS_RET(op_desc->AddOutputDesc("features", tensor), "Failed to add output features for aipp node"); + + return SUCCESS; } domi::AippOpParams::AippMode AippOp::GetAippMode() { return aipp_params_->aipp_mode(); } @@ -298,6 +292,12 @@ NodePtr AippOp::FindDataByIndex(const ComputeGraphPtr &graph, int 
rank) { if (node->GetType() != DATA) { continue; } + + // For functional multi batch, Skip Data for index. + if (node->GetOpDesc()->HasAttr(ATTR_INSERT_BY_MBATCH)) { + continue; + } + // There is no `index` attribute on the `Data` node when compile in inference scene // so we can only use the order of all `Data` nodes to infer the data index if (data_index++ != rank) { @@ -306,6 +306,8 @@ NodePtr AippOp::FindDataByIndex(const ComputeGraphPtr &graph, int rank) { return node; } GELOGE(PARAM_INVALID, "Can not find the data node by index %d", rank); + string errormsg = "Can not find the data node by aipp parameter related_input_rank " + to_string(rank); + ErrorManager::GetInstance().ATCReportErrMessage("E10043", {"reason"}, {errormsg}); return nullptr; } Status AippOp::GetAndCheckTarget(const ComputeGraphPtr &graph, int rank, NodePtr &target, @@ -315,6 +317,17 @@ Status AippOp::GetAndCheckTarget(const ComputeGraphPtr &graph, int rank, NodePtr GELOGE(PARAM_INVALID, "Get target input node for rank %d failed", rank); return PARAM_INVALID; } + auto data_opdesc = data_node->GetOpDesc(); + GE_CHECK_NOTNULL(data_opdesc); + string set_dt_str; + if (ge::AttrUtils::GetStr(data_opdesc, ATTR_ATC_USER_DEFINE_DATATYPE, set_dt_str)) { + ErrorManager::GetInstance().ATCReportErrMessage("E10034", {"opname"}, {data_opdesc->GetName()}); + GELOGE(INTERNAL_ERROR, + "This input op [%s] is linked to aipp, can not be set to fp16, " + "please check your atc parameter --insert_op_conf, --input_fp16_nodes.", + data_opdesc->GetName().c_str()); + return PARAM_INVALID; + } // In scenario AIPP+CONV2D+POOLING, keep the aipp info to Data, since AIPP disappear after subgraph optimize GeAttrValue::NAMED_ATTRS aipp_attr; @@ -333,13 +346,22 @@ Status AippOp::GetAndCheckTarget(const ComputeGraphPtr &graph, int rank, NodePtr if (!edge_indexes.empty() && (*edge_indexes.rbegin() >= data_node->GetOutDataNodes().size())) { GELOGE(PARAM_INVALID, "input_edge_idx %u should smaller than out edge size of target input 
%zu", *edge_indexes.rbegin(), data_node->GetOutDataNodes().size()); + string errormsg = "The aipp parameter input_edge_idx should be smaller than the target input's outnodes."; + ErrorManager::GetInstance().ATCReportErrMessage("E10043", {"reason"}, {errormsg}); return PARAM_INVALID; } target = data_node; + return GetStaticTargetNode(graph, data_node, target); +} + +Status AippOp::GetStaticTargetNode(const ComputeGraphPtr &graph, NodePtr &data_node, NodePtr &target) { + if (GetAippMode() != domi::AippOpParams::static_) { + return SUCCESS; + } + std::string related_node_name; - if ((GetAippMode() == domi::AippOpParams::static_) && - AttrUtils::GetStr(data_node->GetOpDesc(), kMbatchSwitchnName, related_node_name)) { + if (AttrUtils::GetStr(data_node->GetOpDesc(), kMbatchSwitchnName, related_node_name)) { if (related_node_name.empty()) { GELOGE(INTERNAL_ERROR, "The data node %s has switchn node flag, but the value is empty", data_node->GetName().c_str()); @@ -356,10 +378,26 @@ Status AippOp::GetAndCheckTarget(const ComputeGraphPtr &graph, int rank, NodePtr "Multi-batch/image size and static aipp for data %s, " "the aipp node will be insert after %s instead of origin data node", data_node->GetName().c_str(), switchn->GetName().c_str()); + + return SUCCESS; + } + + const auto out_anchor = data_node->GetOutDataAnchor(0); + for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { + if (in_anchor == nullptr) { + continue; + } + + const auto &case_node = in_anchor->GetOwnerNode(); + if (case_node->GetType() == CASE) { + target = case_node; + return SUCCESS; + } } return SUCCESS; } + Status AippOp::GetTargetPosition(ComputeGraphPtr graph, NodePtr &target_input, std::vector> &target_edges) { GE_CHECK_NOTNULL(graph); @@ -374,12 +412,39 @@ Status AippOp::GetTargetPosition(ComputeGraphPtr graph, NodePtr &target_input, } target_edges.clear(); - for (OutDataAnchorPtr &src_out : target_input->GetAllOutDataAnchors()) { - auto dst_ins = src_out->GetPeerInDataAnchors(); - for 
(uint32_t i = 0; i < dst_ins.size(); ++i) { - auto dst_in = dst_ins.at(i); - if (edge_indexes.empty() || edge_indexes.count(i) > 0) { - target_edges.emplace_back(src_out, dst_in); + if (target_input->GetType() != CASE) { + for (OutDataAnchorPtr &src_out : target_input->GetAllOutDataAnchors()) { + auto dst_ins = src_out->GetPeerInDataAnchors(); + for (uint32_t i = 0; i < dst_ins.size(); ++i) { + auto dst_in = dst_ins.at(i); + if (edge_indexes.empty() || edge_indexes.count(i) > 0) { + target_edges.emplace_back(src_out, dst_in); + } + } + } + } else { + const auto &func_desc = target_input->GetOpDesc(); + for (const auto &name : func_desc->GetSubgraphInstanceNames()) { + const auto &subgraph = graph->GetSubgraph(name); + if (subgraph == nullptr) { + GELOGE(GE_GRAPH_EMPTY_SUBGRAPH, "Subgraph not found, name: %s", name.c_str()); + return GE_GRAPH_EMPTY_SUBGRAPH; + } + + auto data_node = FindDataByIndex(subgraph, related_input_rank); + if (data_node == nullptr) { + GELOGE(PARAM_INVALID, "Get target input node for rank %d failed", related_input_rank); + return PARAM_INVALID; + } + + for (OutDataAnchorPtr &src_out : data_node->GetAllOutDataAnchors()) { + auto dst_ins = src_out->GetPeerInDataAnchors(); + for (uint32_t i = 0; i < dst_ins.size(); ++i) { + auto dst_in = dst_ins.at(i); + if (edge_indexes.empty() || edge_indexes.count(i) > 0) { + target_edges.emplace_back(src_out, dst_in); + } + } } } } @@ -639,22 +704,41 @@ void AippOp::ConvertParamToAttr(GeAttrValue::NAMED_ATTRS &aipp_attrs) { SAVE_AIPP_ATTR(support_rotation, GeAttrValue::BOOL); } } -Status AippOp::CreateAippData(const ComputeGraphPtr &graph, const NodePtr &aipp_node) { +Status AippOp::CreateAippData(const NodePtr &aipp_node) { GELOGD("Enter add aipp data node process."); // get previous node, it should be DATA auto data_node = aipp_node->GetInDataNodes().at(kAippImageInputIndex); - GE_CHECK_NOTNULL(data_node->GetOpDesc()); + auto data_op_desc = data_node->GetOpDesc(); + GE_CHECK_NOTNULL(data_op_desc); auto 
ori_data_format = GetAndCheckFormat(); if (ori_data_format != FORMAT_NCHW && ori_data_format != FORMAT_NHWC) { + string format_str = TypeUtils::FormatToSerialString(ori_data_format); GELOGE(PARAM_INVALID, "when dynamic aipp, input_format must be NCHW or NHWC, but [%s] format is %s", - data_node->GetName().c_str(), TypeUtils::FormatToSerialString(ori_data_format).c_str()); + data_node->GetName().c_str(), format_str.c_str()); + string reason = "format must be NCHW or NHWC in dynamic aipp process"; + ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"}, + {data_node->GetName(), "format " + format_str, reason}); return PARAM_INVALID; } + // dynamic aipp shape HWC is not fixed, need to be set -1 + int64_t data_shape_n = 0; + // dynamic batch or HW, need acquire N from ATTR_MBATCH_ORIGIN_INPUT_DIMS + if (data_op_desc->HasAttr(ATTR_MBATCH_ORIGIN_INPUT_DIMS)) { + vector origin_input_dims; + (void)AttrUtils::GetListInt(data_op_desc, ATTR_MBATCH_ORIGIN_INPUT_DIMS, origin_input_dims); + if (!origin_input_dims.empty()) { + data_shape_n = origin_input_dims[0]; + } + } else { + data_shape_n = data_op_desc->MutableInputDesc(0)->GetShape().GetDim(0); + } + vector dynamic_aipp_linked_data_shape{data_shape_n, kDynamicDim, kDynamicDim, kDynamicDim}; + (void)AttrUtils::SetListInt(data_op_desc, "_dynamic_aipp_input_dims", dynamic_aipp_linked_data_shape); + int64_t batch_count = -1; - auto ret = GetDataDimN(data_node, ori_data_format, batch_count); - if (ret != ge::SUCCESS) { + if (GetDataDimN(data_node, ori_data_format, batch_count) != ge::SUCCESS) { GELOGE(PARAM_INVALID, "Get data_node dims and transfer to nchw_dims failed!"); return PARAM_INVALID; } @@ -670,6 +754,10 @@ Status AippOp::CreateAippData(const ComputeGraphPtr &graph, const NodePtr &aipp_ } GELOGI("Add aipp input data, batch count is %ld, max_dynamic_aipp_size is %ld", batch_count, max_dynamic_aipp_size); + return AddNodeToGraph(aipp_node, max_dynamic_aipp_size); +} + +Status 
AippOp::AddNodeToGraph(const NodePtr &aipp_node, int64_t max_dynamic_aipp_size) { std::vector input_shape_dim(1, max_dynamic_aipp_size); GeShape input_shape(input_shape_dim); // construct input tensor @@ -677,14 +765,19 @@ Status AippOp::CreateAippData(const ComputeGraphPtr &graph, const NodePtr &aipp_ TensorUtils::SetReuseInput(input_tensor, false); TensorUtils::SetSize(input_tensor, max_dynamic_aipp_size); - string node_name = kDynamicAippData; // Only flush subgraph name - if (graph->GetParentGraph() != nullptr) { - node_name = graph->GetName() + "_" + node_name; - } + const ComputeGraphPtr &graph = aipp_node->GetOwnerComputeGraph(); + string node_name = (graph->GetParentGraph() == nullptr) ? kDynamicAippData : (graph->GetName() + "_" + node_name); + // new add aipp_data ops for dynamic aipp param input OpDescPtr op_desc_ptr_data = MakeShared(node_name, AIPPDATA); GE_CHECK_NOTNULL(op_desc_ptr_data); + + // Add dynamic aipp config to aipp_data + GeAttrValue::NAMED_ATTRS aipp_attr; + ConvertParamToAttr(aipp_attr); + (void)AttrUtils::SetNamedAttrs(op_desc_ptr_data, ATTR_NAME_AIPP, aipp_attr); + auto stat1 = op_desc_ptr_data->AddInputDesc(input_tensor); GeShape output_shape(input_shape_dim); diff --git a/src/ge/graph/preprocess/insert_op/ge_aipp_op.h b/src/ge/graph/preprocess/insert_op/ge_aipp_op.h index d4f916e4..c98935ee 100644 --- a/src/ge/graph/preprocess/insert_op/ge_aipp_op.h +++ b/src/ge/graph/preprocess/insert_op/ge_aipp_op.h @@ -73,9 +73,11 @@ class AippOp : public InsertOpBase { void SetDtcDefaultValue(); NodePtr FindDataByIndex(const ComputeGraphPtr &graph, int rank); Status GetAndCheckTarget(const ComputeGraphPtr &graph, int rank, NodePtr &target, std::set &edge_indexes); - NodePtr CreateAipp(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, - const std::string &aippConfigPath, const uint32_t &index); - Status CreateAippData(const ComputeGraphPtr &graph, const NodePtr &aipp); + Status GetStaticTargetNode(const ComputeGraphPtr &graph, 
NodePtr &data_node, NodePtr &target); + NodePtr CreateAipp(const OutDataAnchorPtr &out_anchor, const std::string &aippConfigPath, const uint32_t &index); + Status CreateAippData(const NodePtr &aipp); + Status AddNodeToGraph(const NodePtr &aipp_node, int64_t max_dynamic_aipp_size); + Status AddAippAttrbutes(const OpDescPtr &op_desc, const std::string &aipp_cfg_path, const uint32_t &index); domi::AippOpParams *aipp_params_ = nullptr; ge::NodePtr aipp_node_ = nullptr; diff --git a/src/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc b/src/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc index 8bb0c6c4..38bc595e 100644 --- a/src/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc +++ b/src/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc @@ -380,10 +380,6 @@ Status InsertNewOpUtil::GetDataRelatedNode(NodePtr &node, std::mapaipp_mode() != domi::AippOpParams::static_) { - return SUCCESS; - } - for (auto out_data_anchor : node->GetAllOutDataAnchors()) { GE_CHECK_NOTNULL(out_data_anchor); auto peer_in_anchors = out_data_anchor->GetPeerInDataAnchors(); @@ -448,6 +444,8 @@ Status InsertNewOpUtil::RecordAIPPInfoToData(const ComputeGraphPtr &graph) { std::vector input_dims; std::vector output_dims; auto data_node = it.first; + auto data_op_desc = data_node->GetOpDesc(); + GE_CHECK_NOTNULL(data_op_desc); std::set aipps_or_switchs = it.second; if (aipps_or_switchs.size() != 1) { GELOGW("The number of successors swith or aipp of data is more than 1"); @@ -469,15 +467,21 @@ Status InsertNewOpUtil::RecordAIPPInfoToData(const ComputeGraphPtr &graph) { // When static aipp is set, need to get the model input dims which processed by aipp GE_RETURN_IF_ERROR(SetModelInputDims(data_node, aipp_it)); } - - if (!AttrUtils::SetListStr(data_node->GetOpDesc(), ATTR_NAME_AIPP_INPUTS, input_dims)) { - GELOGE(FAILED, "SetListStr of %s failed.", ATTR_NAME_AIPP_INPUTS.c_str()); - return FAILED; - } - - if (!AttrUtils::SetListStr(data_node->GetOpDesc(), ATTR_NAME_AIPP_OUTPUTS, 
output_dims)) { - GELOGE(FAILED, "SetListStr of %s failed.", ATTR_NAME_AIPP_OUTPUTS.c_str()); - return FAILED; + // if _all_origin_gears_inputs is set, use its value directly. + if (data_op_desc->HasAttr("_all_origin_gears_inputs")) { + std::vector input_dims_str; + (void)AttrUtils::GetListStr(data_op_desc, "_all_origin_gears_inputs", input_dims_str); + (void)AttrUtils::SetListStr(data_op_desc, ATTR_NAME_AIPP_INPUTS, input_dims_str); + if ((input_dims_str.size() > output_dims.size()) && !output_dims.empty()) { + // make sure output and input counts is equal, appears in dynamic aipp and dynamic shape/batch scene. + std::vector output_dims_str{input_dims_str.size(), output_dims[0]}; + (void)AttrUtils::SetListStr(data_op_desc, ATTR_NAME_AIPP_OUTPUTS, output_dims_str); + } else { + (void)AttrUtils::SetListStr(data_op_desc, ATTR_NAME_AIPP_OUTPUTS, output_dims); + } + } else { + (void)AttrUtils::SetListStr(data_op_desc, ATTR_NAME_AIPP_INPUTS, input_dims); + (void)AttrUtils::SetListStr(data_op_desc, ATTR_NAME_AIPP_OUTPUTS, output_dims); } } @@ -550,7 +554,7 @@ Status InsertNewOpUtil::SetModelInputDims(NodePtr &data_node, NodePtr &aipp_node } } } - GELOGD("After set H/W to -1, the model input dims: %s.", formats::JoinToString(model_input_dims).c_str()); + GELOGD("After set N or H/W to -1, the model input dims: %s.", formats::JoinToString(model_input_dims).c_str()); if (!AttrUtils::SetListInt(data_opdesc, ATTR_NAME_INPUT_DIMS, model_input_dims)) { GELOGE(FAILED, "SetListInt of %s failed.", ATTR_NAME_INPUT_DIMS.c_str()); return FAILED; diff --git a/src/ge/graph/preprocess/multi_batch_copy_graph.cc b/src/ge/graph/preprocess/multi_batch_copy_graph.cc index d1e9fe62..8a066b6a 100644 --- a/src/ge/graph/preprocess/multi_batch_copy_graph.cc +++ b/src/ge/graph/preprocess/multi_batch_copy_graph.cc @@ -34,7 +34,13 @@ #include "graph/utils/attr_utils.h" #include "graph/utils/graph_utils.h" #include "graph/utils/node_utils.h" +#include "graph/utils/tensor_utils.h" +#include 
"graph/utils/type_utils.h" +#include "graph/preprocess/multi_batch_options.h" +#include "inc/pass_manager.h" +#include "graph/passes/multi_batch_clone_pass.h" +using std::set; using std::string; using std::vector; @@ -48,9 +54,6 @@ const int kDataOutIndex = 0; const int kDataInIndex = 0; const int kMergeDataOutIndex = 0; const int kStaticOutput = -1; -const int kDecimal = 10; -const size_t kMaxShapesCount = 100; -const size_t kMinShapesCount = 2; inline bool IsDataLikeType(const std::string &node_type) { return (node_type == DATA) || (node_type == AIPP); } @@ -90,7 +93,8 @@ NodePtr InsertMergeNodeToGraph(const std::string &name, size_t input_num, const return graph->AddNode(desc); } -NodePtr InsertCopyNode(const NodePtr &node, const std::string &name) { +NodePtr InsertCopyNode(const NodePtr &node, size_t n) { + const std::string &name = node->GetName() + "_ascend_mbatch_batch_" + std::to_string(n); auto src_op_desc = node->GetOpDesc(); GE_IF_BOOL_EXEC(src_op_desc == nullptr, GELOGE(INTERNAL_ERROR, "Failed to copy node %s to %s, the OpDesc is null", node->GetName().c_str(), name.c_str()); @@ -121,39 +125,16 @@ NodePtr InsertCopyNode(const NodePtr &node, const std::string &name) { output_desc->CopyAttrsFrom(src_op_desc->GetOutputDesc(i)); } + const std::string &batch_label = "Batch_" + std::to_string(n); + if (!AttrUtils::SetStr(desc, ATTR_NAME_BATCH_LABEL, batch_label)) { + GELOGE(FAILED, "set attr ATTR_NAME_BATCH_LABEL failed, node:%s.", name.c_str()); + return nullptr; + } + auto graph = node->GetOwnerComputeGraph(); return graph->AddNode(desc); } -Status CalcShape(const std::vector &batch_shape, GeShape &data_shape) { - size_t batch_shape_index = 0; - for (size_t i = 0; i < data_shape.GetDimNum(); ++i) { - if (data_shape.GetDim(i) < 0) { - if (batch_shape_index >= batch_shape.size()) { - ErrorManager::GetInstance().ATCReportErrMessage( - "E19012", {"function", "reason"}, - {"CalcShape", "the batch shape count " + std::to_string(batch_shape.size()) + - " does not 
match the data shape " + data_shape.ToString()}); - GELOGE(PARAM_INVALID, - "Failed to calc tensor shape, the batch shape count %zu, does not match the data shape %s", - batch_shape.size(), data_shape.ToString().c_str()); - return PARAM_INVALID; - } - data_shape.SetDim(i, batch_shape[batch_shape_index++]); - } - } - if (batch_shape_index != batch_shape.size()) { - ErrorManager::GetInstance().ATCReportErrMessage( - "E19012", {"function", "reason"}, - {"CalcShape", "the batch shape count " + std::to_string(batch_shape.size()) + " does not match the data shape " + - data_shape.ToString()}); - GELOGE(PARAM_INVALID, "Failed to calc tensor shape, the batch shape count %zu, does not match the data shape %s", - batch_shape.size(), data_shape.ToString().c_str()); - return PARAM_INVALID; - } - return SUCCESS; -} - bool IsAllDimsPositive(const std::vector &dims) { for (auto dim : dims) { if (dim < 0) { @@ -232,6 +213,11 @@ Status MultiBatchGraphCopyer::CopyGraph() { return ret; } + if (LabelStatus() != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Failed to label status for all nodes."); + return INTERNAL_ERROR; + } + ret = CreateNewNodes(); if (ret != SUCCESS) { return ret; @@ -266,6 +252,55 @@ Status MultiBatchGraphCopyer::Init() { } return SUCCESS; } + +Status MultiBatchGraphCopyer::LabelStatus() { + for (const auto &data : origin_data_nodes_) { + origin_nodes_status_[data.get()] = kNodeInBatchBranch; + } + bool changed = true; + // If anyone of in node is kNodeInBatchBranch, it is also kNodeInBatchBranch + while (changed) { + changed = false; + for (const auto &node : origin_all_nodes_) { + auto iter = origin_nodes_status_.find(node.get()); + if (iter != origin_nodes_status_.end()) { + continue; + } + for (auto &in_node : node->GetInAllNodes()) { + if (origin_nodes_status_.find(in_node.get()) != origin_nodes_status_.end() && + origin_nodes_status_[in_node.get()] == kNodeInBatchBranch) { + origin_nodes_status_[node.get()] = kNodeInBatchBranch; + changed = true; + break; + } + } + } 
+ } + + for (const auto &node : origin_all_nodes_) { + if (!(node->GetOpDesc()->GetSubgraphInstanceNames().empty())) { + origin_nodes_status_[node.get()] = kNodeNotSupportNode; + continue; + } + if (node->GetType() == NETOUTPUT) { + origin_nodes_status_[node.get()] = kNodeOutBatchBranch; + continue; + } + if (IsDataLikeType(node->GetType())) { + if (IsOnlyOutputToAipp(node)) { + origin_nodes_status_[node.get()] = kNodeOutBatchBranch; + } else { + origin_nodes_status_[node.get()] = kNodeStartNode; + } + continue; + } + if (origin_nodes_status_.find(node.get()) == origin_nodes_status_.end()) { + origin_nodes_status_[node.get()] = kNodeOutBatchBranch; + } + } + return SUCCESS; +} + Status MultiBatchGraphCopyer::CreateNewNodes() { shape_data_ = InsertShapeDataNode(); if (shape_data_ == nullptr) { @@ -280,18 +315,22 @@ Status MultiBatchGraphCopyer::CreateNewNodes() { GELOGD("Process node %s, status %d", node->GetName().c_str(), static_cast(branch_status)); switch (branch_status) { case kNodeStartNode: + GELOGD("Name: %s, type: %s, status: kNodeStartNode.", node->GetName().c_str(), node->GetType().c_str()); ret = InsertSwitchNForData(node); if (ret == SUCCESS) { ret = UpdateMaxShapeToData(node); } break; case kNodeInBatchBranch: + GELOGD("Name: %s, type: %s, status: kNodeInBatchBranch.", node->GetName().c_str(), node->GetType().c_str()); ret = CopyNodeInBatchBranch(node); break; case kNodeOutBatchBranch: + GELOGD("Name: %s, type: %s, status: kNodeOutBatchBranch.", node->GetName().c_str(), node->GetType().c_str()); ret = InsertMergeForEdgeNode(node); break; case kNodeNotSupportNode: + GELOGD("Name: %s, type: %s, status: kNodeNotSupportNode.", node->GetName().c_str(), node->GetType().c_str()); break; default: GELOGE(INTERNAL_ERROR, "Unexpected status %d on node %s", static_cast(branch_status), @@ -306,26 +345,6 @@ Status MultiBatchGraphCopyer::CreateNewNodes() { return SUCCESS; } -NodeStatus MultiBatchGraphCopyer::GetNodeStatus(const NodePtr &node) { - // node with 
subgraph is not supported - if (!(node->GetOpDesc()->GetSubgraphInstanceNames().empty())) { - return kNodeNotSupportNode; - } - - if (node->GetType() == NETOUTPUT) { - return kNodeOutBatchBranch; - } - if (IsDataLikeType(node->GetType()) && !IsOnlyOutputToAipp(node)) { - return kNodeStartNode; - } - for (auto &in_node : node->GetInDataNodes()) { - if (IsInBatchBranch(in_node)) { - return kNodeInBatchBranch; - } - } - return kNodeOutBatchBranch; -} - NodePtr MultiBatchGraphCopyer::InsertMergeNode(const NodePtr &node, int index) { if (index < 0) { // the merge node must has data inputs, if origin connection is a control @@ -497,52 +516,8 @@ Status MultiBatchGraphCopyer::CheckArguments() { GELOGE(PARAM_INVALID, "Failed to copy graph, the graph is null"); return PARAM_INVALID; } - if (shapes_.size() < kMinShapesCount) { - ErrorManager::GetInstance().ATCReportErrMessage( - "E10035", {"shapesize", "minshapesize"}, {std::to_string(shapes_.size()), std::to_string(kMinShapesCount - 1)}); - GELOGE(PARAM_INVALID, - "Input parameter[--dynamic_batch_size, --dynamic_image_size or --dynamic_dims]'s " - "value size [%zu] must be greater than [%zu].", - shapes_.size(), kMinShapesCount - 1); - return PARAM_INVALID; - } - if (shapes_.size() > kMaxShapesCount) { - ErrorManager::GetInstance().ATCReportErrMessage( - "E10036", {"shapesize", "maxshapesize"}, {std::to_string(shapes_.size()), std::to_string(kMaxShapesCount + 1)}); - GELOGE(PARAM_INVALID, - "Input parameter[--dynamic_batch_size, --dynamic_image_size or --dynamic_dims]'s " - "value size [%zu] must be less than [%zu].", - shapes_.size(), kMaxShapesCount + 1); - return PARAM_INVALID; - } - std::set> shapes_set; - size_t shape_size = shapes_.at(0).size(); - for (auto &shape : shapes_) { - if (shape_size != shape.size()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10037", {"shapesize1", "shapesize2"}, - {std::to_string(shape_size), std::to_string(shape.size())}); - GELOGE(PARAM_INVALID, - "Input 
parameter[--dynamic_batch_size, --dynamic_image_size or --dynamic_dims]'s " - "value size must be same, first group's size is %zu and another's is %zu.", - shape_size, shape.size()); - return PARAM_INVALID; - } - for (auto dim : shape) { - if (dim <= 0) { - ErrorManager::GetInstance().ATCReportErrMessage("E10038", {"dim"}, {std::to_string(dim)}); - GELOGE(PARAM_INVALID, "Invalid dim %ld, all dims must be greater than 0", dim); - return PARAM_INVALID; - } - } - shapes_set.insert(shape); - } - if (shapes_set.size() != shapes_.size()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10039"); - GELOGE(PARAM_INVALID, - "Input parameter[--dynamic_batch_size, --dynamic_image_size or --dynamic_dims] exist duplicate shapes."); - return PARAM_INVALID; - } - return SUCCESS; + + return CheckDynamicParams(shapes_); } Status MultiBatchGraphCopyer::CheckCopyResult(const std::vector &start_nodes) { for (auto &node : start_nodes) { @@ -564,7 +539,7 @@ bool MultiBatchGraphCopyer::IsInBatchBranch(const NodePtr &node) { Status MultiBatchGraphCopyer::LinkDataToMerge(const NodePtr &data, const NodePtr &merge) { // The caller should make sure that the there is a SwitchN node in the map auto &switchn = data_nodes_to_switchn_[data.get()]; - GELOGI("Link edge bwetween data %s to merge %s throw switchn %s", data->GetName().c_str(), merge->GetName().c_str(), + GELOGI("Link edge between data %s to merge %s throw switchn %s", data->GetName().c_str(), merge->GetName().c_str(), switchn->GetName().c_str()); for (size_t i = 0; i < shapes_.size(); ++i) { auto ret = GraphUtils::AddEdge(switchn->GetOutDataAnchor(i), merge->GetInDataAnchor(i)); @@ -617,17 +592,18 @@ Status MultiBatchGraphCopyer::LinkNodeToMerge(const NodePtr &node, int out_index } Status MultiBatchGraphCopyer::UpdateMaxShapeToData(const NodePtr &data) { auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape(); + auto data_name = data->GetName(); if (IsAllDimsPositive(data_shape.GetDims())) { return SUCCESS; } - 
size_t max_shape_index = 0; int64_t max_size = 0; for (size_t i = 0; i < shapes_.size(); ++i) { int64_t size = 1; - for (auto dim : shapes_[i]) { + for (auto dim : data_to_dynamic_info_.at(data_name).at(i)) { if (INT64_MAX / dim < size) { - GELOGE(PARAM_INVALID, "The shape %s size overflow", formats::ShapeToString(shapes_[i]).c_str()); + GELOGE(PARAM_INVALID, "The shape %s size overflow", + formats::ShapeToString(data_to_dynamic_info_[data_name].at(i)).c_str()); return PARAM_INVALID; } size *= dim; @@ -637,10 +613,8 @@ Status MultiBatchGraphCopyer::UpdateMaxShapeToData(const NodePtr &data) { max_shape_index = i; } } - // must not be error, the calc result has been checked in function InsertSwitchNForData - (void)CalcShape(shapes_[max_shape_index], data_shape); - + (void)CalcShape(data_to_dynamic_info_.at(data_name).at(max_shape_index), data_shape); auto ret = NodeUtils::UpdateOutputShape(*data, kDataOutIndex, data_shape); if (ret != GRAPH_SUCCESS) { GELOGE(INTERNAL_ERROR, "Failed to update output shape for data %s", data->GetName().c_str()); @@ -657,6 +631,7 @@ Status MultiBatchGraphCopyer::UpdateMaxShapeToData(const NodePtr &data) { } Status MultiBatchGraphCopyer::InsertSwitchNForData(const NodePtr &data) { auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape(); + auto data_name = data->GetName(); (void)AttrUtils::SetListInt(data->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims()); if (IsAllDimsPositive(data_shape.GetDims())) { @@ -681,30 +656,41 @@ Status MultiBatchGraphCopyer::InsertSwitchNForData(const NodePtr &data) { if (switchn_desc->AddInputDesc("pred_value", pred_tensor) != GRAPH_SUCCESS) { // pred return OUT_OF_MEMORY; } + std::vector input_dims_str; for (size_t i = 0; i < shapes_.size(); ++i) { auto shape = data_shape; - auto ret = CalcShape(shapes_.at(i), shape); + auto ret = CalcShape(data_to_dynamic_info_.at(data_name).at(i), shape); if (ret != SUCCESS) { GELOGE(ret, "Failed to calculate the batched shape for data 
node %s, the shapes may not match", data->GetName().c_str()); return ret; } tensor.SetShape(shape); + string input_str; + int64_t tensor_size = 0; + (void)TensorUtils::GetTensorSizeInBytes(tensor, tensor_size); + input_str = TypeUtils::FormatToSerialString(tensor.GetFormat()) + ":" + + TypeUtils::DataTypeToSerialString(tensor.GetDataType()) + ":" + data->GetName() + ":" + + std::to_string(tensor_size) + ":" + std::to_string(tensor.GetShape().GetDimNum()) + ":" + + formats::JoinToString(tensor.GetShape().GetDims()); + input_dims_str.emplace_back(input_str); if (!AttrUtils::SetListInt(tensor, ATTR_NAME_SWITCHN_PRED_VALUE, shapes_.at(i))) { GELOGE(INTERNAL_ERROR, "Failed to add attr value on output %zu tensor", i); return INTERNAL_ERROR; } - if (!AttrUtils::SetListInt(tensor, ATTR_NAME_COMBINED_DYNAMIC_DIMS, shape.GetDims())) { - GELOGE(INTERNAL_ERROR, "Failed to add attr ATTR_NAME_COMBINED_DYNAMIC_DIMS on output %zu tensor", i); - return INTERNAL_ERROR; - } + (void)AttrUtils::SetListInt(tensor, ATTR_NAME_COMBINED_DYNAMIC_DIMS, shape.GetDims()); if (switchn_desc->AddOutputDesc("output" + std::to_string(i), tensor) != GRAPH_SUCCESS) { GELOGE(GRAPH_FAILED, "Opdesc AddOutputDesc failed"); return GRAPH_FAILED; } GELOGD("The SwitchN %s output index %zu, shape %s", switchn_desc->GetName().c_str(), i, shape.ToString().c_str()); } - + (void)AttrUtils::SetListStr(data->GetOpDesc(), "_all_origin_gears_inputs", input_dims_str); + if (!AttrUtils::SetListStr(switchn_desc, ATTR_USER_DESIGNEATE_SHAPE_ORDER, data_name_order_)) { + GELOGE(INTERNAL_ERROR, "Failed to add user designate shape order attr on switchn node %s", + switchn_desc->GetName().c_str()); + return INTERNAL_ERROR; + } if (!AttrUtils::SetBool(switchn_desc, ATTR_INSERT_BY_MBATCH, true)) { GELOGE(INTERNAL_ERROR, "Failed to add insert attr on switchn node %s", switchn_desc->GetName().c_str()); return INTERNAL_ERROR; @@ -713,7 +699,7 @@ Status MultiBatchGraphCopyer::InsertSwitchNForData(const NodePtr &data) { 
GELOGE(INTERNAL_ERROR, "Failed to add switchn attr on data node %s", data->GetName().c_str()); return INTERNAL_ERROR; } - if (StampDynamicTypeForSwitchN(switchn_desc) != SUCCESS) { + if (StampDynamicType(switchn_desc) != SUCCESS) { GELOGE(INTERNAL_ERROR, "Failed to add dynamic type attr on switchn node %s", switchn_desc->GetName().c_str()); return INTERNAL_ERROR; } @@ -727,25 +713,6 @@ Status MultiBatchGraphCopyer::InsertSwitchNForData(const NodePtr &data) { return SUCCESS; } -Status MultiBatchGraphCopyer::StampDynamicTypeForSwitchN(OpDescPtr &switchn_desc) { - GE_CHECK_NOTNULL(switchn_desc); - int32_t dynamic_type = static_cast(FIXED); - if (!domi::GetContext().dynamic_batch_size.empty()) { - dynamic_type = static_cast(DYNAMIC_BATCH); - } - if (!domi::GetContext().dynamic_image_size.empty()) { - dynamic_type = static_cast(DYNAMIC_IMAGE); - } - if (!domi::GetContext().dynamic_dims.empty()) { - dynamic_type = static_cast(DYNAMIC_DIMS); - } - if (!AttrUtils::SetInt(switchn_desc, ATTR_DYNAMIC_TYPE, dynamic_type)) { - GELOGE(INTERNAL_ERROR, "Failed to add dynamic type attr of switchn node %s", switchn_desc->GetName().c_str()); - return INTERNAL_ERROR; - } - return SUCCESS; -} - Status MultiBatchGraphCopyer::InsertMergeForEdgeNode(const NodePtr &node) { for (auto &in_data_anchor : node->GetAllInDataAnchors()) { auto src_out_anchor = in_data_anchor->GetPeerOutAnchor(); @@ -778,7 +745,7 @@ Status MultiBatchGraphCopyer::InsertMergeForEdgeNode(const NodePtr &node) { Status MultiBatchGraphCopyer::CopyNodeInBatchBranch(const NodePtr &node) { auto ©ed_nodes = nodes_to_batch_nodes_[node.get()]; for (size_t i = 0; i < shapes_.size(); ++i) { - auto copyed_node = InsertCopyNode(node, node->GetName() + "_ascend_mbatch_batch_" + std::to_string(i)); + auto copyed_node = InsertCopyNode(node, i); if (copyed_node == nullptr) { GELOGE(INTERNAL_ERROR, "Failed to add node to graph when copy node %s", node->GetName().c_str()); return INTERNAL_ERROR; @@ -945,122 +912,182 @@ Status 
MultiBatchGraphCopyer::LinkToNodeOutBranch(const NodePtr &node) { } Status ProcessMultiBatch(ComputeGraphPtr &graph) { - std::vector> shapes; - if (!domi::GetContext().dynamic_batch_size.empty()) { - GELOGD("Found dynamic batch option, value %s", domi::GetContext().dynamic_batch_size.c_str()); - std::vector dims = ge::StringUtils::Split(domi::GetContext().dynamic_batch_size, ','); - for (const auto &dim : dims) { - if (dim.empty()) { - continue; - } - shapes.emplace_back(std::vector({std::strtol(dim.c_str(), nullptr, kDecimal)})); - GELOGI("Found dynamic batch, shape %s", formats::JoinToString(*shapes.rbegin()).c_str()); - } - } - - if (!domi::GetContext().dynamic_image_size.empty()) { - GELOGD("Found dynamic image size option, value %s", domi::GetContext().dynamic_image_size.c_str()); - ParseDynamicSize(domi::GetContext().dynamic_image_size, shapes); - - for (const auto &shape : shapes) { - GELOGI("Found dynamic image size, shape %s", formats::JoinToString(shape).c_str()); - } + const char *multi_batch_with_case = std::getenv("MULTI_BATCH_WITH_CASE"); + if (multi_batch_with_case != nullptr) { + PassManager pass_manager; + GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass)); + return pass_manager.Run(graph); } - if (!domi::GetContext().dynamic_dims.empty()) { - GELOGD("Found dynamic dims option, value %s", domi::GetContext().dynamic_dims.c_str()); - ParseDynamicSize(domi::GetContext().dynamic_dims, shapes); - - for (const auto &shape : shapes) { - GELOGI("Found dynamic dims, shape %s", formats::JoinToString(shape).c_str()); - } - } - - if (shapes.empty()) { + std::vector> shapes; + if (!InitDynamicParams(shapes)) { GELOGD("There is no multi-batch options, no need to process multi-batch copy"); return SUCCESS; } + map>> data_to_dynamic_info; + if (ParserDataToDynmaicInfo(shapes, data_to_dynamic_info) != SUCCESS) { + GELOGD("Parse each data's own dynamic info failed"); + return SUCCESS; + } + + std::vector>> 
user_designate_shape; + user_designate_shape = domi::GetContext().user_input_dims; GELOGI("Begin to copy graph for multi-batch"); multibatch::MultiBatchGraphCopyer copyer(graph); for (auto &shape : shapes) { copyer.AddShape(shape); } + copyer.SetUserDesignateShape(user_designate_shape); + copyer.SetDataToDynamicInfo(data_to_dynamic_info); return copyer.CopyGraph(); } -void ParseDynamicSize(string dynamic_size, vector> &shapes) { - std::vector shape_strs = ge::StringUtils::Split(dynamic_size, ';'); - for (const auto &shape_str : shape_strs) { - if (shape_str.empty()) { - continue; +// +-----------+ +// | Data | +-----------+ +-----------+ +-----------+ +// +-----------+ | Data | ----> | SoftmaxV2 | ----> | NetOutput | +// \ /. +-----------+ +-----------+ +-----------+ +// \ /. +// +-----------+ +-----------+ /. +-----------+ +-----------+ +-----------+ +// | Data | ----> | Case | S--- | Data | ----> | SoftmaxV2 | ----> | NetOutput | +// +-----------+ +-----------+ \. +-----------+ +-----------+ +-----------+ +// \ \. +// \ \. 
+-----------+ +-----------+ +-----------+ +// +-----------+ | Data | ----> | SoftmaxV2 | ----> | NetOutput | +// | NetOutput | +-----------+ +-----------+ +-----------+ +// +-----------+ +// +-----------+ / +// | Data | --------------->/ +// +-----------+ +void GetDynamicShapeByGraph(const ComputeGraphPtr &graph, const NodePtr &node, set &dynamic_output_index, + vector &dynamic_output_dims) { + GELOGD("Try get dynamic shape info, Graph: %s, Node: %s", graph->GetName().c_str(), node->GetName().c_str()); + const auto &func_desc = node->GetOpDesc(); + if (!func_desc->HasAttr(ATTR_NAME_BATCH_NUM)) { + GELOGD("Graph: %s Not multi-batch, Node: %s", graph->GetName().c_str(), node->GetName().c_str()); + return; + } + + const auto &dynamic_branch_names = func_desc->GetSubgraphInstanceNames(); + for (size_t i = 0; i < func_desc->GetOutputsSize(); ++i) { + for (size_t j = 0; j < dynamic_branch_names.size(); ++j) { + const auto &subgraph = graph->GetSubgraph(dynamic_branch_names[j]); + if (subgraph == nullptr) { + GELOGE(GE_GRAPH_EMPTY_SUBGRAPH, "Subgraph not found, name: %s", dynamic_branch_names[j].c_str()); + dynamic_output_dims.clear(); + return; + } + + const auto &out_node = subgraph->FindFirstNodeMatchType(NETOUTPUT); + if (out_node == nullptr) { + GELOGE(GE_GRAPH_GRAPH_NODE_NULL, "NetOutput not found, name: %s", dynamic_branch_names[j].c_str()); + dynamic_output_dims.clear(); + return; + } + + GELOGI("Find the subgraph Output node %s and the index is %zu", out_node->GetName().c_str(), i); + const auto &out_desc = out_node->GetOpDesc(); + if (out_desc == nullptr || out_desc->GetInputsSize() <= i) { + GELOGE(GE_GRAPH_GRAPH_NODE_NULL, "Get Input desc failed, name: %s, index: %zu", out_node->GetName().c_str(), i); + dynamic_output_dims.clear(); + return; + } + + const auto &input_tensor = out_desc->GetInputDesc(i); + const auto &shape_msg = input_tensor.GetShape().ToString(); + string output_shape = std::to_string(j) + "," + std::to_string(i) + "," + shape_msg; + 
GELOGI("The shape msg in dynamic batch is %s", output_shape.c_str()); + dynamic_output_dims.emplace_back(output_shape); + + uint32_t parent_index = 0; + (void)AttrUtils::GetInt(input_tensor, ATTR_NAME_PARENT_NODE_INDEX, parent_index); + dynamic_output_index.insert(parent_index); } - std::vector shape; - std::vector dims = ge::StringUtils::Split(shape_str, ','); - for (const auto &dim : dims) { - if (dim.empty()) { - continue; + } +} + +// +-----------+ +-----------+ i = 0 +// +----> | SoftmaxV2 | ----> |MemcpyAsync| ----> \. +// / +-----------+ +-----------+ \. +// / \. +// +-----------+ +-----------+ +-----------+ +-----------+ i = 1 +-----------+ +// | Data | ----> | SwitchN | ----> | SoftmaxV2 | ----> |MemcpyAsync| ----> | Merge | +// +-----------+ +-----------+ +-----------+ +-----------+ +-----------+ +// \ / \. j = 0 +// \ +-----------+ +-----------+ i = 2 / \. +// +----> | SoftmaxV2 | ----> |MemcpyAsync| ----> / +-----------+ +// +-----------+ +-----------+ | NetOutput | +// +-----------+ +// +-----------+ /. +// | Data | --------------------------------------------------------------------------->/. 
j = 1 +// +-----------+ +void GetDynamicShapeByMerge(const ComputeGraphPtr &graph, const NodePtr &node, set &dynamic_output_index, + vector &dynamic_output_dims) { + GELOGD("Try get dynamic shape info, Graph: %s, Node: %s", graph->GetName().c_str(), node->GetName().c_str()); + const auto &netoutput_desc = node->GetOpDesc(); + const auto &inputnode_to_netoutput = node->GetInAllNodes(); + for (size_t i = 0; i < inputnode_to_netoutput.size(); ++i) { + bool insert_by_mbatch = false; + (void)AttrUtils::GetBool(inputnode_to_netoutput.at(i)->GetOpDesc(), ATTR_INSERT_BY_MBATCH, insert_by_mbatch); + if (inputnode_to_netoutput.at(i)->GetType() == MERGE && insert_by_mbatch) { + GELOGI("Find the merge node %s with mbatch attr and the index is %zu", + inputnode_to_netoutput.at(i)->GetName().c_str(), i); + dynamic_output_index.insert(i); + for (size_t j = 0; j < inputnode_to_netoutput.at(i)->GetInNodes().size(); ++j) { + auto input_desc = inputnode_to_netoutput.at(i)->GetOpDesc(); + auto input_tensor_desc = input_desc->GetInputDesc(j); + auto shape_msg = input_tensor_desc.GetShape().ToString(); + string output_shape = std::to_string(j) + "," + std::to_string(i) + "," + shape_msg; + GELOGI("The shape msg in dynamic batch is %s", output_shape.c_str()); + dynamic_output_dims.emplace_back(output_shape); } - shape.emplace_back(std::strtol(dim.c_str(), nullptr, kDecimal)); } - if (!shape.empty()) { - shapes.emplace_back(shape); + } +} + +// Connect NetOutput directly: DTS2020070612498 +void GetDirectOutputShape(const ComputeGraphPtr &graph, const NodePtr &node, const set &dynamic_output_index, + vector &dynamic_output_dims) { + GELOGD("Try get directly shape info, Graph: %s, Node: %s", graph->GetName().c_str(), node->GetName().c_str()); + const auto &netoutput_desc = node->GetOpDesc(); + const auto &inputnode_to_netoutput = node->GetInAllNodes(); + for (size_t i = 0; i < inputnode_to_netoutput.size(); ++i) { + if (dynamic_output_index.count(i) > 0) { + continue; } + + auto tensor_desc 
= netoutput_desc->GetInputDesc(i); + auto shape = tensor_desc.GetShape().ToString(); + string static_output_shape = std::to_string(kStaticOutput) + "," + std::to_string(i) + "," + shape; + GELOGI("The static output shape msg is %s", static_output_shape.c_str()); + dynamic_output_dims.emplace_back(static_output_shape); } } Status GetDynamicOutputShape(ComputeGraphPtr &graph) { - GELOGI("Start to get dynamic output dynamic batch shape msg"); - std::vector dynamic_output_dims; - if (graph == nullptr) { - GELOGE(PARAM_INVALID, "Graph is null ,para is invalid"); - return PARAM_INVALID; - } + GE_CHECK_NOTNULL(graph); + GELOGI("Start to get output dynamic batch shape message"); + + NodePtr net_output; + set dynamic_output_index; + vector dynamic_output_dims; for (auto &node : graph->GetDirectNode()) { if (node->GetType() == NETOUTPUT) { - auto netoutput_desc = node->GetOpDesc(); - auto inputnode_to_netoutput = node->GetInAllNodes(); - std::vector dynamic_output_index; - for (size_t j = 0; j < inputnode_to_netoutput.size(); j++) { - bool ret = false; - (void)AttrUtils::GetBool(inputnode_to_netoutput.at(j)->GetOpDesc(), ATTR_INSERT_BY_MBATCH, ret); - if (inputnode_to_netoutput.at(j)->GetType() == MERGE && ret) { - GELOGI("Find the merge node %s with mbatch attr and the index is %zu", - inputnode_to_netoutput.at(j)->GetName().c_str(), j); - dynamic_output_index.emplace_back(j); - for (size_t i = 0; i < inputnode_to_netoutput.at(j)->GetInNodes().size(); i++) { - auto input_desc = inputnode_to_netoutput.at(j)->GetOpDesc(); - auto input_tensor_desc = input_desc->GetInputDesc(i); - auto shape_msg = input_tensor_desc.GetShape().ToString(); - std::string output_shape = std::to_string(i) + "," + std::to_string(j) + "," + shape_msg; - GELOGI("The shape msg in dynamic batch is %s", output_shape.c_str()); - dynamic_output_dims.emplace_back(output_shape); - } - } - } - if (dynamic_output_dims.size() > 0) { - for (size_t k = 0; k < inputnode_to_netoutput.size(); k++) { - auto it = 
std::find(dynamic_output_index.begin(), dynamic_output_index.end(), k); - if (it != dynamic_output_index.end()) { - continue; - } - auto tensor_desc = netoutput_desc->GetInputDesc(k); - auto shape = tensor_desc.GetShape().ToString(); - std::string static_output_shape = std::to_string(kStaticOutput) + "," + std::to_string(k) + "," + shape; - GELOGI("The static output shape msg is %s", static_output_shape.c_str()); - dynamic_output_dims.emplace_back(static_output_shape); - } - if (!AttrUtils::SetListStr(netoutput_desc, ATTR_NAME_DYNAMIC_OUTPUT_DIMS, dynamic_output_dims)) { - GELOGE(FAILED, "Set dynamic output dims attr failed"); - return FAILED; - } - return SUCCESS; - } - GELOGI("Can not find the merge node with mbatch attr"); - return SUCCESS; + net_output = node; + GetDynamicShapeByMerge(graph, node, dynamic_output_index, dynamic_output_dims); + } else if (node->GetType() == CASE) { + GetDynamicShapeByGraph(graph, node, dynamic_output_index, dynamic_output_dims); } } - GELOGW("There are no netoutput in graph"); + + if ((net_output != nullptr) && !dynamic_output_dims.empty()) { + GetDirectOutputShape(graph, net_output, dynamic_output_index, dynamic_output_dims); + if (!AttrUtils::SetListStr(net_output->GetOpDesc(), ATTR_NAME_DYNAMIC_OUTPUT_DIMS, dynamic_output_dims)) { + GELOGE(FAILED, "Set dynamic output dims attr failed"); + return FAILED; + } + } + return SUCCESS; } } // namespace multibatch diff --git a/src/ge/graph/preprocess/multi_batch_copy_graph.h b/src/ge/graph/preprocess/multi_batch_copy_graph.h index 7e317cb0..a0e61554 100644 --- a/src/ge/graph/preprocess/multi_batch_copy_graph.h +++ b/src/ge/graph/preprocess/multi_batch_copy_graph.h @@ -27,7 +27,6 @@ namespace ge { namespace multibatch { Status ProcessMultiBatch(ComputeGraphPtr &graph); -void ParseDynamicSize(std::string dynamic_size, std::vector> &shapes); Status GetDynamicOutputShape(ComputeGraphPtr &graph); @@ -44,19 +43,28 @@ class MultiBatchGraphCopyer { ~MultiBatchGraphCopyer() = default; void 
AddShape(const std::vector &shape) { shapes_.emplace_back(shape); } - + void SetUserDesignateShape(const vector>> &designate_shape) { + user_designate_shape_ = designate_shape; + for (auto &item : designate_shape) { + data_name_order_.push_back(item.first); + } + } + void SetDataToDynamicInfo(const map>> &designate_shape) { + data_to_dynamic_info_ = designate_shape; + } Status CopyGraph(); private: Status Init(); Status CheckArguments(); + // label status for origin_all_nodes_ + Status LabelStatus(); // add nodes functions Status CreateNewNodes(); NodePtr InsertShapeDataNode(); Status InsertSwitchNForData(const NodePtr &data); - Status StampDynamicTypeForSwitchN(OpDescPtr &switchn_desc); Status UpdateMaxShapeToData(const NodePtr &data); Status InsertMergeForEdgeNode(const NodePtr &node); @@ -87,7 +95,7 @@ class MultiBatchGraphCopyer { Status CopyInControlEdges(const NodePtr &node, int batch_num, const NodePtr ©ed_node); bool IsInBatchBranch(const NodePtr &node); - NodeStatus GetNodeStatus(const NodePtr &node); + NodeStatus GetNodeStatus(const NodePtr &node) { return origin_nodes_status_[node.get()]; }; Status CheckCopyResult(const std::vector &start_nodes); // arguments @@ -111,6 +119,16 @@ class MultiBatchGraphCopyer { // the nodes on the in/out-batch-branch edge, and the merge nodes inserted after it std::map> nodes_to_merge_nodes_; + + // all nodes and their status + std::map origin_nodes_status_; + + // user designate shape, decord the order of each input data + std::vector>> user_designate_shape_; + std::vector data_name_order_; + + // each data's own dynamic info + map>> data_to_dynamic_info_; }; } // namespace multibatch } // namespace ge diff --git a/src/ge/graph/preprocess/multi_batch_options.cc b/src/ge/graph/preprocess/multi_batch_options.cc new file mode 100644 index 00000000..cbf8206f --- /dev/null +++ b/src/ge/graph/preprocess/multi_batch_options.cc @@ -0,0 +1,258 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the 
Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "multi_batch_options.h" + +#include "framework/common/debug/ge_log.h" +#include "framework/omg/omg_inner_types.h" +#include "framework/common/util.h" +#include "framework/common/string_util.h" +#include "common/formats/utils/formats_trans_utils.h" +#include "common/util/error_manager/error_manager.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/node_utils.h" +#include "graph/ge_context.h" + +namespace ge { +namespace multibatch { +constexpr int kDecimal = 10; +constexpr uint8_t kMaxShapesCount = 100; +constexpr uint8_t kMinShapesCount = 2; + +void ParseDynamicSize(string dynamic_size, vector> &shapes) { + std::vector shape_strs = ge::StringUtils::Split(dynamic_size, ';'); + for (const auto &shape_str : shape_strs) { + if (shape_str.empty()) { + continue; + } + std::vector shape; + std::vector dims = ge::StringUtils::Split(shape_str, ','); + for (const auto &dim : dims) { + if (dim.empty()) { + continue; + } + shape.emplace_back(std::strtol(dim.c_str(), nullptr, kDecimal)); + } + if (!shape.empty()) { + shapes.emplace_back(shape); + } + } +} + +/// +/// @ingroup ge +/// @brief Init Dynamic Param from Options. +/// @param [out] std::vector> &shapes: Result for Params. +/// @return true: Configed for Multi batch / false: Not configed for Multi batch. 
+/// +bool InitDynamicParams(vector> &shapes) { + if (!domi::GetContext().dynamic_batch_size.empty()) { + GELOGD("Found dynamic batch option, value %s", domi::GetContext().dynamic_batch_size.c_str()); + std::vector dims = ge::StringUtils::Split(domi::GetContext().dynamic_batch_size, ','); + for (const auto &dim : dims) { + if (dim.empty()) { + continue; + } + shapes.emplace_back(std::vector({std::strtol(dim.c_str(), nullptr, kDecimal)})); + GELOGI("Found dynamic batch, shape %s", formats::JoinToString(*shapes.rbegin()).c_str()); + } + } + + if (!domi::GetContext().dynamic_image_size.empty()) { + GELOGD("Found dynamic image size option, value %s", domi::GetContext().dynamic_image_size.c_str()); + ParseDynamicSize(domi::GetContext().dynamic_image_size, shapes); + + for (const auto &shape : shapes) { + GELOGI("Found dynamic image size, shape %s", formats::JoinToString(shape).c_str()); + } + } + + if (!domi::GetContext().dynamic_dims.empty()) { + GELOGD("Found dynamic dims option, value %s", domi::GetContext().dynamic_dims.c_str()); + ParseDynamicSize(domi::GetContext().dynamic_dims, shapes); + + for (const auto &shape : shapes) { + GELOGI("Found dynamic dims, shape %s", formats::JoinToString(shape).c_str()); + } + } + + return !shapes.empty(); +} + +/// +/// @ingroup ge +/// @brief parse each data's own dynamic dims. +/// @param [out] map>> &data_to_dynamic_info: key:data_name. value:dynamic dims. +/// @return true: Configed for Multi batch / false: Not configed for Multi batch. 
+/// +Status ParserDataToDynmaicInfo(const vector> &shapes, + map>> &data_to_dynamic_info) { + if (domi::GetContext().user_input_dims.empty()) { + GELOGD("Get user designed shape failed"); + return FAILED; + } + size_t cur_data_index = 0; + for (size_t index = 0; index < domi::GetContext().user_input_dims.size(); ++index) { + auto &cur_item = domi::GetContext().user_input_dims[index]; + auto &data_name = cur_item.first; + auto &data_shape = cur_item.second; + auto dynamic_dims_num = + std::count_if(data_shape.begin(), data_shape.end(), [&data_shape](int64_t dim) { return dim < 0; }); + vector> dynamic_info; + for (auto &dynamic_gear_info : shapes) { + vector one_gear; + if (dynamic_gear_info.size() == static_cast(dynamic_dims_num)) { + one_gear = dynamic_gear_info; + } else if (dynamic_gear_info.size() > static_cast(dynamic_dims_num)) { + auto tmp_index = cur_data_index; + for (size_t i = 0; i < static_cast(dynamic_dims_num); ++i) { + if (tmp_index >= dynamic_gear_info.size()) { + GELOGE(PARAM_INVALID, "Data: %s shape: %s make dynamic dims overflow", data_name.c_str(), + formats::JoinToString(data_shape).c_str()); + return FAILED; + } + one_gear.push_back(dynamic_gear_info[tmp_index++]); + } + } else { + GELOGE(PARAM_INVALID, "Dynamic dims num of data: %s shape: %s can not be more than one gear dynamic info size", + data_name.c_str(), formats::JoinToString(data_shape).c_str()); + return FAILED; + } + dynamic_info.push_back(one_gear); + } + cur_data_index += dynamic_dims_num; + data_to_dynamic_info[data_name] = dynamic_info; + } + return SUCCESS; +} + +/// +/// @ingroup ge +/// @brief Check Dynamic Param is invalid. +/// @param [in] const vector> &shapes: Params for check. +/// @return SUCCESS: valid / PARAM_INVALID: invalid. 
+/// +Status CheckDynamicParams(const vector> &shapes) { + if (shapes.size() < kMinShapesCount) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E10035", {"shapesize", "minshapesize"}, {std::to_string(shapes.size()), std::to_string(kMinShapesCount - 1)}); + GELOGE(PARAM_INVALID, + "Input parameter[--dynamic_batch_size, --dynamic_image_size or --dynamic_dims]'s " + "value size [%zu] must be greater than [%zu].", + shapes.size(), kMinShapesCount - 1); + return PARAM_INVALID; + } + if (shapes.size() > kMaxShapesCount) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E10036", {"shapesize", "maxshapesize"}, {std::to_string(shapes.size()), std::to_string(kMaxShapesCount + 1)}); + GELOGE(PARAM_INVALID, + "Input parameter[--dynamic_batch_size, --dynamic_image_size or --dynamic_dims]'s " + "value size [%zu] must be less than [%zu].", + shapes.size(), kMaxShapesCount + 1); + return PARAM_INVALID; + } + std::set> shapes_set; + size_t shape_size = shapes.at(0).size(); + for (auto &shape : shapes) { + if (shape_size != shape.size()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10037", {"shapesize1", "shapesize2"}, + {std::to_string(shape_size), std::to_string(shape.size())}); + GELOGE(PARAM_INVALID, + "Input parameter[--dynamic_batch_size, --dynamic_image_size or --dynamic_dims]'s " + "value size must be same, first group's size is %zu and another's is %zu.", + shape_size, shape.size()); + return PARAM_INVALID; + } + for (auto dim : shape) { + if (dim <= 0) { + ErrorManager::GetInstance().ATCReportErrMessage("E10038", {"dim"}, {std::to_string(dim)}); + GELOGE(PARAM_INVALID, "Invalid dim %ld, all dims must be greater than 0", dim); + return PARAM_INVALID; + } + } + shapes_set.insert(shape); + } + if (shapes_set.size() != shapes.size()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10039"); + GELOGE(PARAM_INVALID, + "Input parameter[--dynamic_batch_size, --dynamic_image_size or --dynamic_dims] exist duplicate shapes."); + return PARAM_INVALID; + } + 
return SUCCESS; +} + +/// +/// @ingroup ge +/// @brief Get GeShape from configed shape. +/// @param [in] const std::vector &batch_shape: Configed shape. +/// @param [out] GeShape &data_shape: GeShape for configed shape. +/// @return SUCCESS / PARAM_INVALID +/// +Status CalcShape(const std::vector &batch_shape, GeShape &data_shape) { + size_t batch_shape_index = 0; + for (size_t i = 0; i < data_shape.GetDimNum(); ++i) { + if (data_shape.GetDim(i) < 0) { + if (batch_shape_index >= batch_shape.size()) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E19012", {"function", "reason"}, + {"CalcShape", "the batch shape count " + std::to_string(batch_shape.size()) + + " does not match the data shape " + data_shape.ToString()}); + GELOGE(PARAM_INVALID, + "Failed to calc tensor shape, the batch shape count %zu, does not match the data shape %s", + batch_shape.size(), data_shape.ToString().c_str()); + return PARAM_INVALID; + } + data_shape.SetDim(i, batch_shape[batch_shape_index++]); + } + } + if (batch_shape_index != batch_shape.size()) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E19012", {"function", "reason"}, + {"CalcShape", "the batch shape count " + std::to_string(batch_shape.size()) + " does not match the data shape " + + data_shape.ToString()}); + GELOGE(PARAM_INVALID, "Failed to calc tensor shape, the batch shape count %zu, does not match the data shape %s", + batch_shape.size(), data_shape.ToString().c_str()); + return PARAM_INVALID; + } + return SUCCESS; +} + +/// +/// @ingroup ge +/// @brief Set mbatch_dynamic_type on node. +/// @param [in] const OpDescPtr &op_desc: Node for set attribute. 
+/// @return 0: SUCCESS / others: INTERNAL_ERROR +/// +Status StampDynamicType(const OpDescPtr &op_desc) { + GE_CHECK_NOTNULL(op_desc); + int32_t dynamic_type = static_cast(FIXED); + if (!domi::GetContext().dynamic_batch_size.empty()) { + dynamic_type = static_cast(DYNAMIC_BATCH); + } + if (!domi::GetContext().dynamic_image_size.empty()) { + dynamic_type = static_cast(DYNAMIC_IMAGE); + } + if (!domi::GetContext().dynamic_dims.empty()) { + dynamic_type = static_cast(DYNAMIC_DIMS); + } + if (!AttrUtils::SetInt(op_desc, ATTR_DYNAMIC_TYPE, dynamic_type)) { + GELOGE(INTERNAL_ERROR, "Failed to add dynamic type attr for node %s", op_desc->GetName().c_str()); + return INTERNAL_ERROR; + } + return SUCCESS; +} +} // namespace multibatch +} // namespace ge diff --git a/src/ge/graph/preprocess/multi_batch_options.h b/src/ge/graph/preprocess/multi_batch_options.h new file mode 100644 index 00000000..650020d9 --- /dev/null +++ b/src/ge/graph/preprocess/multi_batch_options.h @@ -0,0 +1,72 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PREPROCESS_MULTI_BATCH_OPTIONS_H_ +#define GE_GRAPH_PREPROCESS_MULTI_BATCH_OPTIONS_H_ + +#include + +#include "external/ge/ge_api_error_codes.h" +#include "graph/ge_tensor.h" +#include "graph/op_desc.h" +#include "graph/node.h" + +namespace ge { +namespace multibatch { +/// +/// @ingroup ge +/// @brief Init Dynamic Param from Options. 
+/// @param [out] std::vector> &shapes: Result for Params. +/// @return true: Configed for Multi batch / false: Not configed for Multi batch. +/// +bool InitDynamicParams(std::vector> &shapes); + +/// +/// @ingroup ge +/// @brief Check Dynamic Param is invalid. +/// @param [in] const vector> &shapes: Params for check. +/// @return SUCCESS: valid / PARAM_INVALID: invalid. +/// +Status CheckDynamicParams(const std::vector> &shapes); + +/// +/// @ingroup ge +/// @brief Get GeShape from configed shape. +/// @param [in] const std::vector &batch_shape: Configed shape. +/// @param [out] GeShape &data_shape: GeShape for configed shape. +/// @return SUCCESS / PARAM_INVALID +/// +Status CalcShape(const std::vector &batch_shape, GeShape &data_shape); + +/// +/// @ingroup ge +/// @brief parse each data's own dynamic dims. +/// @param [out] map>> &data_to_dynamic_info: key:data_name. value:dynamic dims. +/// @return SUCCESS / PARAM_INVALID +/// +Status ParserDataToDynmaicInfo(const vector> &shapes, + map>> &data_to_dynamic_info); + +/// +/// @ingroup ge +/// @brief Set mbatch_dynamic_type on node. +/// @param [in] const OpDescPtr &op_desc: Node for set attribute. +/// @return 0: SUCCESS / others: INTERNAL_ERROR +/// +Status StampDynamicType(const OpDescPtr &op_desc); +} // namespace multibatch +} // namespace ge +#endif // GE_GRAPH_PREPROCESS_MULTI_BATCH_OPTIONS_H_ diff --git a/src/ge/host_aicpu_engine/ops_kernel_store/op/assign_op.cc b/src/ge/host_aicpu_engine/ops_kernel_store/op/assign_op.cc deleted file mode 100644 index 32f8ec24..00000000 --- a/src/ge/host_aicpu_engine/ops_kernel_store/op/assign_op.cc +++ /dev/null @@ -1,51 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "host_aicpu_engine/ops_kernel_store/op/assign_op.h" -#include "framework/common/debug/ge_log.h" -#include "framework/common/util.h" -#include "host_aicpu_engine/ops_kernel_store/op/op_factory.h" - -namespace { -const size_t kAssignInputNum = 2; -const size_t kAssignRefInputIndex = 0; -const size_t kAssignValueInputIndex = 1; -const size_t kAssignRefOutputIndex = 0; -} // namespace - -namespace ge { -namespace host_aicpu { -Status AssignOp::Compute(const ge::OpDescPtr &op_desc_ptr, const std::vector &inputs, - std::vector &outputs) { - GELOGI("AssignOp [%s, %s] compute begin.", node_.GetName().c_str(), node_.GetType().c_str()); - if (inputs.size() != kAssignInputNum) { - GELOGE(PARAM_INVALID, "Number of input for AssignOp must be %zu.", kAssignInputNum); - return PARAM_INVALID; - } - auto &ref_input = inputs[kAssignRefInputIndex]; - const auto &value_input = inputs[kAssignValueInputIndex]; - ref_input->SetData(value_input->GetData().GetData(), value_input->GetData().GetSize()); - GeTensorPtr output_ptr = MakeShared(op_desc_ptr->GetOutputDesc(kAssignRefOutputIndex), - value_input->GetData().GetData(), value_input->GetData().GetSize()); - GE_CHECK_NOTNULL(output_ptr); - outputs.push_back(output_ptr); - GELOGI("AssignOp [%s, %s] compute success.", node_.GetName().c_str(), node_.GetType().c_str()); - return SUCCESS; -} - -REGISTER_OP_CREATOR(Assign, AssignOp); -} // namespace host_aicpu -} // namespace ge diff --git a/src/ge/host_aicpu_engine/ops_kernel_store/op/assign_op.h b/src/ge/host_aicpu_engine/ops_kernel_store/op/assign_op.h 
deleted file mode 100644 index caf9d4c9..00000000 --- a/src/ge/host_aicpu_engine/ops_kernel_store/op/assign_op.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_ASSIGN_OP_H_ -#define GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_ASSIGN_OP_H_ - -#include "host_aicpu_engine/ops_kernel_store/op/op.h" - -namespace ge { -namespace host_aicpu { -class AssignOp : public Op { - public: - AssignOp(const Node &node, RunContext &run_context) : Op(node, run_context) {} - ~AssignOp() override = default; - AssignOp &operator=(const AssignOp &op) = delete; - AssignOp(const AssignOp &op) = delete; - - /** - * @brief compute for node_task. 
- * @return result - */ - Status Compute(const ge::OpDescPtr &op_desc_ptr, const std::vector &inputs, - std::vector &outputs) override; -}; -} // namespace host_aicpu -} // namespace ge - -#endif // GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_ASSIGN_OP_H_ diff --git a/src/ge/host_aicpu_engine/ops_kernel_store/op/random_uniform_op.cc b/src/ge/host_aicpu_engine/ops_kernel_store/op/random_uniform_op.cc deleted file mode 100644 index 81768f7a..00000000 --- a/src/ge/host_aicpu_engine/ops_kernel_store/op/random_uniform_op.cc +++ /dev/null @@ -1,104 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "host_aicpu_engine/ops_kernel_store/op/random_uniform_op.h" -#include -#include "framework/common/debug/ge_log.h" -#include "framework/common/util.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/type_utils.h" -#include "host_aicpu_engine/ops_kernel_store/op/op_factory.h" - -namespace ge { -namespace host_aicpu { -Status RandomUniformOp::Compute(const ge::OpDescPtr &op_desc_ptr, const std::vector &inputs, - std::vector &outputs) { - GELOGI("RandomUniformOp [%s, %s] compute begin.", node_.GetName().c_str(), node_.GetType().c_str()); - int64_t seed = 0; - int64_t seed2 = 0; - (void)AttrUtils::GetInt(op_desc_ptr, "seed", seed); - (void)AttrUtils::GetInt(op_desc_ptr, "seed2", seed2); - DataType data_type = DT_UNDEFINED; - if (AttrUtils::GetDataType(op_desc_ptr, VAR_ATTR_DTYPE, data_type) != GRAPH_SUCCESS) { - GELOGE(PARAM_INVALID, "get attr VAR_ATTR_DTYPE failed"); - return PARAM_INVALID; - } - - switch (data_type) { - case DT_FLOAT16: - break; - case DT_FLOAT: - if (Generate(op_desc_ptr, seed, seed2, outputs) != SUCCESS) { - GELOGE(FAILED, "Generate random_distribution for RandomUniformOp failed, data_type=DT_FLOAT"); - return FAILED; - } - break; - case DT_DOUBLE: - if (Generate(op_desc_ptr, seed, seed2, outputs) != SUCCESS) { - GELOGE(FAILED, "Generate random_distribution for RandomUniformOp failed, data_type=DT_DOUBLE"); - return FAILED; - } - break; - default: - GELOGE(UNSUPPORTED, "Supported DataType for RandomUniformOp is DT_FLOAT16 / DT_FLOAT / DT_DOUBLE, but dtype=%s", - TypeUtils::DataTypeToSerialString(data_type).c_str()); - return UNSUPPORTED; - } - - GELOGI("RandomUniformOp [%s, %s] compute success.", node_.GetName().c_str(), node_.GetType().c_str()); - return SUCCESS; -} - -template -Status RandomUniformOp::Generate(const ge::OpDescPtr &op_desc_ptr, int64_t seed, int64_t seed2, - std::vector &outputs) { - GE_CHECK_NOTNULL(op_desc_ptr); - // RandomUniformOp has and only has one output - int64_t data_num = 
op_desc_ptr->GetOutputDesc(0).GetShape().GetShapeSize(); - std::unique_ptr buf(new (std::nothrow) T[data_num]()); - if (buf == nullptr) { - GELOGE(MEMALLOC_FAILED, "New sizeof(T) * data_num(%zu) memory failed", static_cast(sizeof(T) * data_num)); - return MEMALLOC_FAILED; - } - - int64_t final_seed; - if (seed == 0) { - if (seed2 == 0) { - std::random_device rd; - final_seed = rd(); - } else { - final_seed = seed2; - } - } else { - final_seed = seed; - } - std::mt19937_64 gen(final_seed); - std::uniform_real_distribution distribution(0, 1); - for (int64_t i = 0; i < data_num; i++) { - *(buf.get() + i) = distribution(gen); - } - - GeTensorPtr output = - MakeShared(op_desc_ptr->GetOutputDesc(0), reinterpret_cast(buf.get()), data_num * sizeof(T)); - GE_CHECK_NOTNULL(output); - outputs.emplace_back(output); - - return SUCCESS; -} - -REGISTER_OP_CREATOR(RandomUniform, RandomUniformOp); -} // namespace host_aicpu -} // namespace ge diff --git a/src/ge/host_aicpu_engine/ops_kernel_store/op/random_uniform_op.h b/src/ge/host_aicpu_engine/ops_kernel_store/op/random_uniform_op.h deleted file mode 100644 index dfb2485f..00000000 --- a/src/ge/host_aicpu_engine/ops_kernel_store/op/random_uniform_op.h +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_RANDOM_UNIFORM_OP_H_ -#define GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_RANDOM_UNIFORM_OP_H_ - -#include "host_aicpu_engine/ops_kernel_store/op/op.h" - -namespace ge { -namespace host_aicpu { -class RandomUniformOp : public Op { - public: - RandomUniformOp(const Node &node, RunContext &run_context) : Op(node, run_context) {} - ~RandomUniformOp() override = default; - RandomUniformOp &operator=(const RandomUniformOp &op) = delete; - RandomUniformOp(const RandomUniformOp &op) = delete; - - /** - * @brief compute for node_task. - * @return result - */ - Status Compute(const ge::OpDescPtr &op_desc_ptr, const std::vector &inputs, - std::vector &outputs) override; - - private: - template - Status Generate(const ge::OpDescPtr &op_desc_ptr, int64_t seed, int64_t seed2, std::vector &outputs); -}; -} // namespace host_aicpu -} // namespace ge - -#endif // GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_RANDOM_UNIFORM_OP_H_ diff --git a/src/ge/host_aicpu_engine/ops_kernel_store/op/variable_op.cc b/src/ge/host_aicpu_engine/ops_kernel_store/op/variable_op.cc deleted file mode 100644 index effa346b..00000000 --- a/src/ge/host_aicpu_engine/ops_kernel_store/op/variable_op.cc +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "host_aicpu_engine/ops_kernel_store/op/variable_op.h" -#include "framework/common/debug/ge_log.h" -#include "framework/common/util.h" -#include "host_aicpu_engine/ops_kernel_store/op/op_factory.h" - -namespace { -const size_t kInputSize = 1; -} - -namespace ge { -namespace host_aicpu { -Status VariableOp::Compute(const ge::OpDescPtr &op_desc_ptr, const std::vector &inputs, - std::vector &outputs) { - GELOGI("VariableOp [%s, %s] compute begin.", node_.GetName().c_str(), node_.GetType().c_str()); - if (inputs.size() != kInputSize) { - GELOGE(PARAM_INVALID, "Number of input for VariableOp must be %zu.", kInputSize); - return PARAM_INVALID; - } - GeTensorPtr output_ptr = - MakeShared(op_desc_ptr->GetOutputDesc(0), inputs[0]->GetData().GetData(), inputs[0]->GetData().GetSize()); - GE_CHECK_NOTNULL(output_ptr); - outputs.push_back(output_ptr); - GELOGI("VariableOp [%s, %s] compute success.", node_.GetName().c_str(), node_.GetType().c_str()); - return SUCCESS; -} - -REGISTER_OP_CREATOR(Variable, VariableOp); -REGISTER_OP_CREATOR(Constant, VariableOp); -} // namespace host_aicpu -} // namespace ge diff --git a/src/ge/host_aicpu_engine/ops_kernel_store/op/variable_op.h b/src/ge/host_aicpu_engine/ops_kernel_store/op/variable_op.h deleted file mode 100644 index b6570557..00000000 --- a/src/ge/host_aicpu_engine/ops_kernel_store/op/variable_op.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_VARIABLE_OP_H_ -#define GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_VARIABLE_OP_H_ - -#include "host_aicpu_engine/ops_kernel_store/op/op.h" - -namespace ge { -namespace host_aicpu { -class VariableOp : public Op { - public: - VariableOp(const Node &node, RunContext &run_context) : Op(node, run_context) {} - ~VariableOp() override = default; - VariableOp &operator=(const VariableOp &op) = delete; - VariableOp(const VariableOp &op) = delete; - - /** - * @brief compute for node_task. - * @return result - */ - Status Compute(const ge::OpDescPtr &op_desc_ptr, const std::vector &inputs, - std::vector &outputs) override; -}; -} // namespace host_aicpu -} // namespace ge - -#endif // GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_VARIABLE_OP_H_ diff --git a/src/ge/host_aicpu_engine/common/constant/constant.h b/src/ge/host_cpu_engine/common/constant/constant.h similarity index 66% rename from src/ge/host_aicpu_engine/common/constant/constant.h rename to src/ge/host_cpu_engine/common/constant/constant.h index 998dc7eb..a3cabdc4 100644 --- a/src/ge/host_aicpu_engine/common/constant/constant.h +++ b/src/ge/host_cpu_engine/common/constant/constant.h @@ -14,17 +14,17 @@ * limitations under the License. 
*/ -#ifndef GE_HOST_AICPU_ENGINE_COMMON_CONSTANT_CONSTANT_H_ -#define GE_HOST_AICPU_ENGINE_COMMON_CONSTANT_CONSTANT_H_ +#ifndef GE_HOST_CPU_ENGINE_COMMON_CONSTANT_CONSTANT_H_ +#define GE_HOST_CPU_ENGINE_COMMON_CONSTANT_CONSTANT_H_ #include namespace ge { -namespace host_aicpu { +namespace host_cpu { // engine name -const char kHostAiCpuEngineName[] = "DNN_VM_HOST_AICPU"; -const char kHostAiCpuOpKernelLibName[] = "DNN_VM_HOST_AICPU_OP_STORE"; -} // namespace host_aicpu +const char kHostCpuEngineName[] = "DNN_VM_HOST_CPU"; +const char kHostCpuOpKernelLibName[] = "DNN_VM_HOST_CPU_OP_STORE"; +} // namespace host_cpu } // namespace ge -#endif // GE_HOST_AICPU_ENGINE_COMMON_CONSTANT_CONSTANT_H_ +#endif // GE_HOST_CPU_ENGINE_COMMON_CONSTANT_CONSTANT_H_ diff --git a/src/ge/host_aicpu_engine/engine/host_aicpu_engine.cc b/src/ge/host_cpu_engine/engine/host_cpu_engine.cc similarity index 52% rename from src/ge/host_aicpu_engine/engine/host_aicpu_engine.cc rename to src/ge/host_cpu_engine/engine/host_cpu_engine.cc index 12ec5ede..648e13b1 100644 --- a/src/ge/host_aicpu_engine/engine/host_aicpu_engine.cc +++ b/src/ge/host_cpu_engine/engine/host_cpu_engine.cc @@ -14,61 +14,61 @@ * limitations under the License. 
*/ -#include "host_aicpu_engine/engine/host_aicpu_engine.h" +#include "host_cpu_engine/engine/host_cpu_engine.h" #include #include #include #include "framework/common/debug/ge_log.h" #include "common/ge/ge_util.h" -#include "host_aicpu_engine/common/constant/constant.h" -#include "host_aicpu_engine/ops_kernel_store/host_aicpu_ops_kernel_info.h" +#include "host_cpu_engine/common/constant/constant.h" +#include "host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.h" namespace ge { -namespace host_aicpu { -HostAiCpuEngine &HostAiCpuEngine::Instance() { - static HostAiCpuEngine instance; +namespace host_cpu { +HostCpuEngine &HostCpuEngine::Instance() { + static HostCpuEngine instance; return instance; } -Status HostAiCpuEngine::Initialize(const std::map &options) { +Status HostCpuEngine::Initialize(const std::map &options) { if (ops_kernel_store_ == nullptr) { - ops_kernel_store_ = MakeShared(); + ops_kernel_store_ = MakeShared(); if (ops_kernel_store_ == nullptr) { - GELOGE(FAILED, "Make HostAiCpuOpsKernelInfoStore failed."); + GELOGE(FAILED, "Make HostCpuOpsKernelInfoStore failed."); return FAILED; } } return SUCCESS; } -void HostAiCpuEngine::GetOpsKernelInfoStores(std::map &ops_kernel_map) { +void HostCpuEngine::GetOpsKernelInfoStores(std::map &ops_kernel_map) { if (ops_kernel_store_ != nullptr) { // add buildin opsKernel to opsKernelInfoMap - ops_kernel_map[kHostAiCpuOpKernelLibName] = ops_kernel_store_; + ops_kernel_map[kHostCpuOpKernelLibName] = ops_kernel_store_; } } -void HostAiCpuEngine::GetGraphOptimizerObjs(std::map &) { - // no optimizer for host aicpu engine +void HostCpuEngine::GetGraphOptimizerObjs(std::map &) { + // no optimizer for host cpu engine } -Status HostAiCpuEngine::Finalize() { +Status HostCpuEngine::Finalize() { ops_kernel_store_ = nullptr; return SUCCESS; } -} // namespace host_aicpu +} // namespace host_cpu } // namespace ge ge::Status Initialize(const std::map &options) { - return 
ge::host_aicpu::HostAiCpuEngine::Instance().Initialize(options); + return ge::host_cpu::HostCpuEngine::Instance().Initialize(options); } void GetOpsKernelInfoStores(std::map &ops_kernel_map) { - ge::host_aicpu::HostAiCpuEngine::Instance().GetOpsKernelInfoStores(ops_kernel_map); + ge::host_cpu::HostCpuEngine::Instance().GetOpsKernelInfoStores(ops_kernel_map); } void GetGraphOptimizerObjs(std::map &graph_optimizers) { - ge::host_aicpu::HostAiCpuEngine::Instance().GetGraphOptimizerObjs(graph_optimizers); + ge::host_cpu::HostCpuEngine::Instance().GetGraphOptimizerObjs(graph_optimizers); } -ge::Status Finalize() { return ge::host_aicpu::HostAiCpuEngine::Instance().Finalize(); } +ge::Status Finalize() { return ge::host_cpu::HostCpuEngine::Instance().Finalize(); } diff --git a/src/ge/host_aicpu_engine/engine/host_aicpu_engine.h b/src/ge/host_cpu_engine/engine/host_cpu_engine.h similarity index 71% rename from src/ge/host_aicpu_engine/engine/host_aicpu_engine.h rename to src/ge/host_cpu_engine/engine/host_cpu_engine.h index f8ad71b1..ecafd98b 100644 --- a/src/ge/host_aicpu_engine/engine/host_aicpu_engine.h +++ b/src/ge/host_cpu_engine/engine/host_cpu_engine.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef GE_HOST_AICPU_ENGINE_ENGINE_HOST_AICPU_ENGINE_H_ -#define GE_HOST_AICPU_ENGINE_ENGINE_HOST_AICPU_ENGINE_H_ +#ifndef GE_HOST_CPU_ENGINE_ENGINE_HOST_CPU_ENGINE_H_ +#define GE_HOST_CPU_ENGINE_ENGINE_HOST_CPU_ENGINE_H_ #include #include @@ -27,20 +27,20 @@ using OpsKernelInfoStorePtr = std::shared_ptr; using GraphOptimizerPtr = std::shared_ptr; namespace ge { -namespace host_aicpu { +namespace host_cpu { /** - * host aicpu engine. + * host cpu engine. * Used for the ops which executes on host. */ -class HostAiCpuEngine { +class HostCpuEngine { public: /** - * get HostAiCpuEngine instance. - * @return HostAiCpuEngine instance. + * get HostCpuEngine instance. + * @return HostCpuEngine instance. 
*/ - static HostAiCpuEngine &Instance(); + static HostCpuEngine &Instance(); - virtual ~HostAiCpuEngine() = default; + virtual ~HostCpuEngine() = default; /** * When Ge start, GE will invoke this interface @@ -51,14 +51,14 @@ class HostAiCpuEngine { /** * After the initialize, GE will invoke this interface * to get the Ops kernel Store. - * @param ops_kernel_map The host aicpu's ops kernel info + * @param ops_kernel_map The host cpu's ops kernel info */ void GetOpsKernelInfoStores(std::map &ops_kernel_map); /** * After the initialize, GE will invoke this interface * to get the Graph Optimizer. - * @param graph_optimizers The host aicpu's Graph Optimizer objs + * @param graph_optimizers The host cpu's Graph Optimizer objs */ void GetGraphOptimizerObjs(std::map &graph_optimizers); @@ -68,17 +68,17 @@ class HostAiCpuEngine { */ Status Finalize(); - HostAiCpuEngine(const HostAiCpuEngine &HostAiCpuEngine) = delete; - HostAiCpuEngine(const HostAiCpuEngine &&HostAiCpuEngine) = delete; - HostAiCpuEngine &operator=(const HostAiCpuEngine &HostAiCpuEngine) = delete; - HostAiCpuEngine &operator=(HostAiCpuEngine &&HostAiCpuEngine) = delete; + HostCpuEngine(const HostCpuEngine &HostCpuEngine) = delete; + HostCpuEngine(const HostCpuEngine &&HostCpuEngine) = delete; + HostCpuEngine &operator=(const HostCpuEngine &HostCpuEngine) = delete; + HostCpuEngine &operator=(HostCpuEngine &&HostCpuEngine) = delete; private: - HostAiCpuEngine() = default; + HostCpuEngine() = default; OpsKernelInfoStorePtr ops_kernel_store_ = nullptr; }; -} // namespace host_aicpu +} // namespace host_cpu } // namespace ge extern "C" { @@ -91,13 +91,13 @@ ge::Status Initialize(const map &options); /** * After the initialize, GE will invoke this interface to get the Ops kernel Store - * @param ops_kernel_map The host aicpu's ops kernel info + * @param ops_kernel_map The host cpu's ops kernel info */ void GetOpsKernelInfoStores(std::map &ops_kernel_map); /** * After the initialize, GE will invoke this interface 
to get the Graph Optimizer - * @param graph_optimizers The host aicpu's Graph Optimizer objs + * @param graph_optimizers The host cpu's Graph Optimizer objs */ void GetGraphOptimizerObjs(std::map &graph_optimizers); @@ -108,4 +108,4 @@ void GetGraphOptimizerObjs(std::map &graph_optim ge::Status Finalize(); } -#endif // GE_HOST_AICPU_ENGINE_ENGINE_HOST_AICPU_ENGINE_H_ +#endif // GE_HOST_CPU_ENGINE_ENGINE_HOST_CPU_ENGINE_H_ diff --git a/src/ge/host_aicpu_engine/module.mk b/src/ge/host_cpu_engine/module.mk similarity index 88% rename from src/ge/host_aicpu_engine/module.mk rename to src/ge/host_cpu_engine/module.mk index 48dd6a87..41de4503 100644 --- a/src/ge/host_aicpu_engine/module.mk +++ b/src/ge/host_cpu_engine/module.mk @@ -1,8 +1,8 @@ LOCAL_PATH := $(call my-dir) -local_lib_src_files := engine/host_aicpu_engine.cc \ - ops_kernel_store/host_aicpu_ops_kernel_info.cc \ +local_lib_src_files := engine/host_cpu_engine.cc \ + ops_kernel_store/host_cpu_ops_kernel_info.cc \ ops_kernel_store/op/op_factory.cc \ ops_kernel_store/op/host_op.cc \ @@ -18,7 +18,7 @@ local_lib_inc_path := proto/task.proto \ #compiler for host include $(CLEAR_VARS) -LOCAL_MODULE := libhost_aicpu_engine +LOCAL_MODULE := libhost_cpu_engine LOCAL_CFLAGS += -Werror LOCAL_CFLAGS += -std=c++11 LOCAL_LDFLAGS := @@ -38,7 +38,7 @@ include ${BUILD_HOST_SHARED_LIBRARY} #compiler for atc include $(CLEAR_VARS) -LOCAL_MODULE := atclib/libhost_aicpu_engine +LOCAL_MODULE := atclib/libhost_cpu_engine LOCAL_CFLAGS += -Werror LOCAL_CFLAGS += -std=c++11 LOCAL_LDFLAGS := diff --git a/src/ge/host_aicpu_engine/ops_kernel_store/host_aicpu_ops_kernel_info.cc b/src/ge/host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc similarity index 82% rename from src/ge/host_aicpu_engine/ops_kernel_store/host_aicpu_ops_kernel_info.cc rename to src/ge/host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc index 4dbedab1..4e7be2d5 100644 --- a/src/ge/host_aicpu_engine/ops_kernel_store/host_aicpu_ops_kernel_info.cc +++ 
b/src/ge/host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "host_aicpu_engine/ops_kernel_store/host_aicpu_ops_kernel_info.h" +#include "host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.h" #include #include "common/constant/constant.h" #include "ge/ge_api_types.h" @@ -28,16 +28,16 @@ #include "proto/task.pb.h" namespace ge { -namespace host_aicpu { +namespace host_cpu { using domi::TaskDef; using std::map; using std::string; using std::vector; -Status HostAiCpuOpsKernelInfoStore::Initialize(const map &options) { - GELOGI("HostAiCpuOpsKernelInfoStore init start."); - OpInfo default_op_info = {.engine = kHostAiCpuEngineName, - .opKernelLib = kHostAiCpuOpKernelLibName, +Status HostCpuOpsKernelInfoStore::Initialize(const map &options) { + GELOGI("HostCpuOpsKernelInfoStore init start."); + OpInfo default_op_info = {.engine = kHostCpuEngineName, + .opKernelLib = kHostCpuOpKernelLibName, .computeCost = 0, .flagPartial = false, .flagAsync = false, @@ -48,17 +48,17 @@ Status HostAiCpuOpsKernelInfoStore::Initialize(const map &option op_info_map_[op] = default_op_info; } - GELOGI("HostAiCpuOpsKernelInfoStore inited success. op num=%zu", op_info_map_.size()); + GELOGI("HostCpuOpsKernelInfoStore inited success. 
op num=%zu", op_info_map_.size()); return SUCCESS; } -Status HostAiCpuOpsKernelInfoStore::Finalize() { +Status HostCpuOpsKernelInfoStore::Finalize() { op_info_map_.clear(); return SUCCESS; } -Status HostAiCpuOpsKernelInfoStore::CalcOpRunningParam(Node &ge_node) { +Status HostCpuOpsKernelInfoStore::CalcOpRunningParam(Node &ge_node) { OpDescPtr op_desc = ge_node.GetOpDesc(); if (op_desc == nullptr) { GELOGE(FAILED, "CalcOpRunningParam failed, as op desc is null"); @@ -111,22 +111,23 @@ Status HostAiCpuOpsKernelInfoStore::CalcOpRunningParam(Node &ge_node) { return FAILED; } } + GELOGD("Calc op[%s:%s] running param success.", name.c_str(), type.c_str()); return SUCCESS; } -void HostAiCpuOpsKernelInfoStore::GetAllOpsKernelInfo(map &infos) const { infos = op_info_map_; } +void HostCpuOpsKernelInfoStore::GetAllOpsKernelInfo(map &infos) const { infos = op_info_map_; } -Status HostAiCpuOpsKernelInfoStore::GenerateTask(const Node &node, RunContext &context, vector &tasks) { +Status HostCpuOpsKernelInfoStore::GenerateTask(const Node &node, RunContext &context, vector &tasks) { // no need to generate device task return SUCCESS; } -bool HostAiCpuOpsKernelInfoStore::CheckSupported(const OpDescPtr &op_desc, std::string &) const { +bool HostCpuOpsKernelInfoStore::CheckSupported(const OpDescPtr &op_desc, std::string &) const { if (op_desc == nullptr) { return false; } return op_info_map_.count(op_desc->GetType()) > 0; } -} // namespace host_aicpu +} // namespace host_cpu } // namespace ge diff --git a/src/ge/host_aicpu_engine/ops_kernel_store/host_aicpu_ops_kernel_info.h b/src/ge/host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.h similarity index 70% rename from src/ge/host_aicpu_engine/ops_kernel_store/host_aicpu_ops_kernel_info.h rename to src/ge/host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.h index a4051b9b..1202cc8a 100644 --- a/src/ge/host_aicpu_engine/ops_kernel_store/host_aicpu_ops_kernel_info.h +++ 
b/src/ge/host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_HOST_AICPU_OPS_KERNEL_INFO_H_ -#define GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_HOST_AICPU_OPS_KERNEL_INFO_H_ +#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_INFO_H_ +#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_INFO_H_ #include #include @@ -24,20 +24,20 @@ #include "common/opskernel/ops_kernel_info_store.h" namespace ge { -namespace host_aicpu { -class HostAiCpuOpsKernelInfoStore : public OpsKernelInfoStore { +namespace host_cpu { +class HostCpuOpsKernelInfoStore : public OpsKernelInfoStore { public: - HostAiCpuOpsKernelInfoStore() {} - ~HostAiCpuOpsKernelInfoStore() override = default; + HostCpuOpsKernelInfoStore() {} + ~HostCpuOpsKernelInfoStore() override = default; /** - * Initialize related resources of the host aicpu kernelinfo store + * Initialize related resources of the host cpu kernelinfo store * @return status whether this operation success */ Status Initialize(const std::map &options) override; /** - * Release related resources of the host aicpu kernel info store + * Release related resources of the host cpu kernel info store * @return status whether this operation success */ Status Finalize() override; @@ -73,16 +73,16 @@ class HostAiCpuOpsKernelInfoStore : public OpsKernelInfoStore { */ Status GenerateTask(const ge::Node &ge_node, ge::RunContext &context, std::vector &tasks) override; - HostAiCpuOpsKernelInfoStore(const HostAiCpuOpsKernelInfoStore &ops_kernel_store) = delete; - HostAiCpuOpsKernelInfoStore(const HostAiCpuOpsKernelInfoStore &&ops_kernel_store) = delete; - HostAiCpuOpsKernelInfoStore &operator=(const HostAiCpuOpsKernelInfoStore &ops_kernel_store) = delete; - HostAiCpuOpsKernelInfoStore &operator=(HostAiCpuOpsKernelInfoStore &&ops_kernel_store) = delete; + HostCpuOpsKernelInfoStore(const HostCpuOpsKernelInfoStore &ops_kernel_store) = 
delete; + HostCpuOpsKernelInfoStore(const HostCpuOpsKernelInfoStore &&ops_kernel_store) = delete; + HostCpuOpsKernelInfoStore &operator=(const HostCpuOpsKernelInfoStore &ops_kernel_store) = delete; + HostCpuOpsKernelInfoStore &operator=(HostCpuOpsKernelInfoStore &&ops_kernel_store) = delete; private: // store op name and OpInfo key-value pair std::map op_info_map_; }; -} // namespace host_aicpu +} // namespace host_cpu } // namespace ge -#endif // GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_HOST_AICPU_OPS_KERNEL_INFO_H_ +#endif // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_INFO_H_ diff --git a/src/ge/host_aicpu_engine/ops_kernel_store/op/host_op.cc b/src/ge/host_cpu_engine/ops_kernel_store/op/host_op.cc similarity index 80% rename from src/ge/host_aicpu_engine/ops_kernel_store/op/host_op.cc rename to src/ge/host_cpu_engine/ops_kernel_store/op/host_op.cc index 9dbd80e0..472fca45 100644 --- a/src/ge/host_aicpu_engine/ops_kernel_store/op/host_op.cc +++ b/src/ge/host_cpu_engine/ops_kernel_store/op/host_op.cc @@ -14,12 +14,12 @@ * limitations under the License. 
*/ -#include "host_aicpu_engine/ops_kernel_store/op/host_op.h" +#include "host_cpu_engine/ops_kernel_store/op/host_op.h" #include "framework/common/util.h" -#include "host_aicpu_engine/ops_kernel_store/op/op_factory.h" +#include "host_cpu_engine/ops_kernel_store/op/op_factory.h" namespace ge { -namespace host_aicpu { +namespace host_cpu { Status HostOp::Run() { // no need to generate device task return SUCCESS; @@ -30,5 +30,7 @@ REGISTER_OP_CREATOR(Variable, HostOp); REGISTER_OP_CREATOR(Constant, HostOp); REGISTER_OP_CREATOR(Assign, HostOp); REGISTER_OP_CREATOR(RandomUniform, HostOp); -} // namespace host_aicpu +REGISTER_OP_CREATOR(Add, HostOp); +REGISTER_OP_CREATOR(Mul, HostOp); +} // namespace host_cpu } // namespace ge diff --git a/src/ge/host_aicpu_engine/ops_kernel_store/op/host_op.h b/src/ge/host_cpu_engine/ops_kernel_store/op/host_op.h similarity index 76% rename from src/ge/host_aicpu_engine/ops_kernel_store/op/host_op.h rename to src/ge/host_cpu_engine/ops_kernel_store/op/host_op.h index 6655e620..757b96a6 100644 --- a/src/ge/host_aicpu_engine/ops_kernel_store/op/host_op.h +++ b/src/ge/host_cpu_engine/ops_kernel_store/op/host_op.h @@ -14,13 +14,13 @@ * limitations under the License. 
*/ -#ifndef GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_HOST_OP_H_ -#define GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_HOST_OP_H_ +#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_HOST_OP_H_ +#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_HOST_OP_H_ -#include "host_aicpu_engine/ops_kernel_store/op/op.h" +#include "host_cpu_engine/ops_kernel_store/op/op.h" namespace ge { -namespace host_aicpu { +namespace host_cpu { class HostOp : public Op { public: HostOp(const Node &node, RunContext &run_context) : Op(node, run_context) {} @@ -30,7 +30,7 @@ class HostOp : public Op { Status Run() override; }; -} // namespace host_aicpu +} // namespace host_cpu } // namespace ge -#endif // GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_HOST_OP_H_ +#endif // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_HOST_OP_H_ diff --git a/src/ge/host_aicpu_engine/ops_kernel_store/op/op.h b/src/ge/host_cpu_engine/ops_kernel_store/op/op.h similarity index 83% rename from src/ge/host_aicpu_engine/ops_kernel_store/op/op.h rename to src/ge/host_cpu_engine/ops_kernel_store/op/op.h index 87c7993e..c1e1619c 100644 --- a/src/ge/host_aicpu_engine/ops_kernel_store/op/op.h +++ b/src/ge/host_cpu_engine/ops_kernel_store/op/op.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_OP_H_ -#define GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_OP_H_ +#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_H_ +#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_H_ #include #include @@ -25,7 +25,7 @@ #include "graph/node.h" namespace ge { -namespace host_aicpu { +namespace host_cpu { /** * The base class for all op. 
*/ @@ -39,7 +39,7 @@ class Op { const RunContext &run_context_; const Node &node_; }; -} // namespace host_aicpu +} // namespace host_cpu } // namespace ge -#endif // GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_OP_H_ +#endif // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_H_ diff --git a/src/ge/host_aicpu_engine/ops_kernel_store/op/op_factory.cc b/src/ge/host_cpu_engine/ops_kernel_store/op/op_factory.cc similarity index 93% rename from src/ge/host_aicpu_engine/ops_kernel_store/op/op_factory.cc rename to src/ge/host_cpu_engine/ops_kernel_store/op/op_factory.cc index ec376d8a..efe44f80 100644 --- a/src/ge/host_aicpu_engine/ops_kernel_store/op/op_factory.cc +++ b/src/ge/host_cpu_engine/ops_kernel_store/op/op_factory.cc @@ -14,13 +14,13 @@ * limitations under the License. */ -#include "host_aicpu_engine/ops_kernel_store/op/op_factory.h" +#include "host_cpu_engine/ops_kernel_store/op/op_factory.h" #include "framework/common/debug/ge_log.h" #include "common/ge_inner_error_codes.h" #include "graph/op_desc.h" namespace ge { -namespace host_aicpu { +namespace host_cpu { OpFactory &OpFactory::Instance() { static OpFactory instance; return instance; @@ -51,5 +51,5 @@ void OpFactory::RegisterCreator(const std::string &type, const OP_CREATOR_FUNC & op_creator_map_[type] = func; all_ops_.emplace_back(type); } -} // namespace host_aicpu +} // namespace host_cpu } // namespace ge diff --git a/src/ge/host_aicpu_engine/ops_kernel_store/op/op_factory.h b/src/ge/host_cpu_engine/ops_kernel_store/op/op_factory.h similarity index 90% rename from src/ge/host_aicpu_engine/ops_kernel_store/op/op_factory.h rename to src/ge/host_cpu_engine/ops_kernel_store/op/op_factory.h index 007bceaa..92f627fd 100644 --- a/src/ge/host_aicpu_engine/ops_kernel_store/op/op_factory.h +++ b/src/ge/host_cpu_engine/ops_kernel_store/op/op_factory.h @@ -14,8 +14,8 @@ * limitations under the License. 
*/ -#ifndef GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_OP_FACTORY_H_ -#define GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_OP_FACTORY_H_ +#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_FACTORY_H_ +#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_FACTORY_H_ #include #include @@ -23,10 +23,10 @@ #include #include #include "common/ge/ge_util.h" -#include "host_aicpu_engine/ops_kernel_store/op/op.h" +#include "host_cpu_engine/ops_kernel_store/op/op.h" namespace ge { -namespace host_aicpu { +namespace host_cpu { using OP_CREATOR_FUNC = std::function(const Node &, RunContext &)>; /** @@ -88,7 +88,7 @@ class OpRegistrar { return MakeShared(node, run_context); \ } \ OpRegistrar g_##type##Op_creator(#type, Creator_##type##Op) -} // namespace host_aicpu +} // namespace host_cpu } // namespace ge -#endif // GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_OP_FACTORY_H_ +#endif // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_FACTORY_H_ diff --git a/src/ge/host_cpu_engine/proto/task.proto b/src/ge/host_cpu_engine/proto/task.proto new file mode 120000 index 00000000..36ae4847 --- /dev/null +++ b/src/ge/host_cpu_engine/proto/task.proto @@ -0,0 +1 @@ +../../proto/task.proto \ No newline at end of file diff --git a/src/ge/host_kernels/concat_offset_kernel.cc b/src/ge/host_kernels/concat_offset_kernel.cc index 0a870949..fe06534f 100644 --- a/src/ge/host_kernels/concat_offset_kernel.cc +++ b/src/ge/host_kernels/concat_offset_kernel.cc @@ -41,7 +41,7 @@ Status ConcatOffsetKernel::Compute(const OpDescPtr op_desc_ptr, const vector(reinterpret_cast(input_0->GetData().data()))); // validate inputs - if (static_cast(input.size()) != (N + kNumOne) || input.size() <= kConcatOffsetInputIndexOne) { - GELOGW("The number of input for concat offset must be equal with %d, and must be more than one.", (N + kNumOne)); + if ((static_cast(input.size()) != (N + kNumOne)) || (input.size() <= kConcatOffsetInputIndexOne)) { + GELOGW("The number of input for concat offset must be equal to %d, and must be more than 
one.", (N + kNumOne)); return NOT_CHANGED; } @@ -58,7 +58,7 @@ Status ConcatOffsetKernel::Compute(const OpDescPtr op_desc_ptr, const vectorGetTensorDesc().GetShape(); int64_t output_size = output_shape.GetShapeSize(); if (concat_dim >= output_size) { - GELOGW("Concat dim is biger than the size of output_shape."); + GELOGW("Concat dim is bigger than the size of output_shape."); return NOT_CHANGED; } GELOGI("Output shape size is %ld", output_size); @@ -93,7 +93,7 @@ Status ConcatOffsetKernel::Compute(const OpDescPtr op_desc_ptr, const vector(input[i + kConcatOffsetInputIndexOne]->GetData().data()); int64_t input_dim = input_shape[concat_dim]; // this index is valid, checked before - if (input_dim > (INT32_MAX - offset)) { + if (input_dim > (INT64_MAX - offset)) { GELOGE(PARAM_INVALID, " %d and %ld addition can result in overflow!.", offset, input_dim); return INTERNAL_ERROR; } diff --git a/src/ge/host_kernels/floordiv_kernel.cc b/src/ge/host_kernels/floordiv_kernel.cc index 05eded80..5114122c 100644 --- a/src/ge/host_kernels/floordiv_kernel.cc +++ b/src/ge/host_kernels/floordiv_kernel.cc @@ -146,7 +146,7 @@ Status FloorDivKernel::DataCalBroadcast(const T &x, const T &y, size_t num_x, si if (num_x > num_y) { if (ZeroCheck(y, data_type)) { - GELOGE(PARAM_INVALID, "The divisor of FloorDiv con not be zero"); + GELOGE(PARAM_INVALID, "The divisor of FloorDiv can not be zero."); return PARAM_INVALID; } for (size_t i = 0; i < num_x; ++i) { @@ -155,7 +155,7 @@ Status FloorDivKernel::DataCalBroadcast(const T &x, const T &y, size_t num_x, si } else { for (size_t i = 0; i < num_y; ++i) { if (ZeroCheck((&y)[i], data_type)) { - GELOGE(PARAM_INVALID, "The divisor of FloorDiv con not be zero"); + GELOGE(PARAM_INVALID, "The divisor of FloorDiv can not be zero."); return PARAM_INVALID; } buf[i] = DivCal(x, (&y)[i]); @@ -195,7 +195,7 @@ Status FloorDivKernel::DataCal(const std::vector &input, GeTen for (size_t i = 0; i < data_num_x; ++i) { if (ZeroCheck(y[i], data_type)) { - 
GELOGE(PARAM_INVALID, "The divisor of FloorDiv con not be zero"); + GELOGE(PARAM_INVALID, "The divisor of FloorDiv can not be zero."); return PARAM_INVALID; } buf[i] = DivCal(x[i], y[i]); diff --git a/src/ge/host_kernels/reduce_prod_kernel.cc b/src/ge/host_kernels/reduce_prod_kernel.cc index 739d4b9f..0a3fad72 100644 --- a/src/ge/host_kernels/reduce_prod_kernel.cc +++ b/src/ge/host_kernels/reduce_prod_kernel.cc @@ -60,7 +60,7 @@ Status ReduceProdKernel::ReduceProdCheck(const ge::OpDescPtr &op_desc_ptr, GE_CHECK_NOTNULL(data_tensor); GE_CHECK_NOTNULL(axis_tensor); if (axis_tensor->GetTensorDesc().GetShape().GetDimNum() > kReduceProdMaxAxisRank) { - GELOGW("Axis must be at most rank 1, node node: %s", op_desc_ptr->GetName().c_str()); + GELOGW("Axis must be at most rank 1, node: %s", op_desc_ptr->GetName().c_str()); return PARAM_INVALID; } diff --git a/src/ge/host_kernels/rsqrt_kernel.cc b/src/ge/host_kernels/rsqrt_kernel.cc index 56972d23..f91e3399 100644 --- a/src/ge/host_kernels/rsqrt_kernel.cc +++ b/src/ge/host_kernels/rsqrt_kernel.cc @@ -27,71 +27,124 @@ #include "framework/common/debug/ge_log.h" #include "host_kernels/kernel_utils.h" #include "inc/kernel_factory.h" +#include "common/math/math_util.h" namespace ge { namespace { const size_t kRsqrtInputSize = 1; const size_t kRsqrtInputIndex0 = 0; + +template +Status ZeroCheck(T x, const DataType &data_type) { + switch (data_type) { + case DT_FLOAT16: + FMK_FP16_ZEROCHECK(static_cast(x)) + break; + case DT_FLOAT: + FMK_FLOAT_ZEROCHECK(static_cast(x)) + break; + case DT_DOUBLE: + FMK_DOUBLE_ZEROCHECK(static_cast(x)) + break; + default: + break; + } + return SUCCESS; +} +#define SET_RSQRT_CASE(DTYPE, TYPE) \ + case (DTYPE): \ + ret = RsqrtKernel::RsqrtCompute(input_ptr, output_ptr); \ + break; } // namespace + +template +Status RsqrtKernel::RsqrtCompute(ConstGeTensorPtr &input_tensor_ptr, GeTensorPtr &output_tensor_ptr) { + GE_CHECK_NOTNULL(input_tensor_ptr); + GE_CHECK_NOTNULL(output_tensor_ptr); + size_t data_size 
= input_tensor_ptr->GetData().size(); + size_t data_count = data_size / sizeof(T); + auto data_type = input_tensor_ptr->GetTensorDesc().GetDataType(); + if (data_count > 0) { + unique_ptr buf(new (std::nothrow) T[data_count]()); + if (buf == nullptr) { + GELOGW("New buf failed"); + return NOT_CHANGED; + } + auto ptr = const_cast(reinterpret_cast(input_tensor_ptr->GetData().data())); + for (size_t i = 0; i < data_count; i++) { + if (ZeroCheck(*(ptr + i), data_type) != SUCCESS) { + GELOGE(PARAM_INVALID, "The input data can not be 0. "); + return PARAM_INVALID; + } + switch (data_type) { + case DT_FLOAT16: { + double val = static_cast(*(reinterpret_cast(input_tensor_ptr->GetData().data()) + i)); + GE_IF_BOOL_EXEC(val < 0, GELOGE(PARAM_INVALID, "The denominator data %lf can not less than 0.", val); + return PARAM_INVALID); + double drSqrt = 1.0 / std::sqrt(val); + buf[i] = drSqrt; + break; + } + case DT_FLOAT: { + float denominator = std::sqrt(*(reinterpret_cast(input_tensor_ptr->GetData().data()) + i)); + buf[i] = static_cast(1 / denominator); + break; + } + case DT_DOUBLE: { + double denominator = std::sqrt(*(reinterpret_cast(input_tensor_ptr->GetData().data()) + i)); + buf[i] = static_cast(1 / denominator); + break; + } + default: + GELOGW("Input data type must be FP16, FP32 and DOUBLE."); + return NOT_CHANGED; + } + } + GE_IF_BOOL_EXEC(output_tensor_ptr->SetData(reinterpret_cast(buf.get()), data_size) != GRAPH_SUCCESS, + GELOGW("Set data failed"); + return NOT_CHANGED); + output_tensor_ptr->MutableTensorDesc().SetDataType(data_type); + output_tensor_ptr->MutableTensorDesc().SetShape(input_tensor_ptr->GetTensorDesc().GetShape()); + } + return SUCCESS; +} + Status RsqrtKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector &input, std::vector &v_output) { GELOGI("RsqrtKernel in."); GE_CHECK_NOTNULL(op_desc_ptr); + // check input size if (input.size() != kRsqrtInputSize) { GELOGW("The number of input for rsqrt must be %zu.", kRsqrtInputSize); - return 
PARAM_INVALID; - } - - ConstGeTensorPtr input_ = input.at(kRsqrtInputIndex0); - GE_CHECK_NOTNULL(input_); - if (input_->GetTensorDesc().GetDataType() != DT_FLOAT) { - GELOGW("input data type must be FP32."); return NOT_CHANGED; } - const GeShape &x_shape = input_->GetTensorDesc().GetShape(); - size_t data_size = input_->GetData().size(); - size_t data_count = data_size / sizeof(float); + ConstGeTensorPtr input_ptr = input.at(kRsqrtInputIndex0); + GE_CHECK_NOTNULL(input_ptr); - // check whether input is zero - for (size_t i = 0; i < data_count; i++) { - if (fabs(*(reinterpret_cast(input_->GetData().data()) + i)) < FLT_EPSILON) { - GELOGW("input must be not equal 0."); - return NOT_CHANGED; - } + // Index 0 can always gets a GeTensorDesc object from any OpDescPtr. + auto output_tensor_desc = op_desc_ptr->GetOutputDesc(0); + GeTensorPtr output_ptr = MakeShared(output_tensor_desc); + if (output_ptr == nullptr) { + GELOGW("MakeShared GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str()); + return NOT_CHANGED; } - if (data_count > 0) { - unique_ptr buf(new (std::nothrow) float[data_count]()); - if (buf == nullptr) { - GELOGW("new buf failed"); + Status ret = NOT_CHANGED; + auto dtype = input_ptr->GetTensorDesc().GetDataType(); + switch (dtype) { + SET_RSQRT_CASE(DT_FLOAT16, fp16_t) + SET_RSQRT_CASE(DT_FLOAT, float) + SET_RSQRT_CASE(DT_DOUBLE, double) + default: + GELOGW("Input data type must be FP16, FP32 and DOUBLE."); return NOT_CHANGED; - } - - for (size_t i = 0; i < data_count; i++) { - float denominator = sqrt(*(reinterpret_cast(input_->GetData().data()) + i)); - if (fabs(denominator) < FLT_EPSILON) { - GELOGW("input must be not equal 0."); - return NOT_CHANGED; - } - buf[i] = 1 / denominator; - } - - // Index 0 can always gets a GeTensorDesc object from any OpDescPtr. 
- auto output_tensor_desc = op_desc_ptr->GetOutputDesc(0); - GeTensorPtr output_ptr = MakeShared(output_tensor_desc); - if (output_ptr == nullptr) { - GELOGW("MakeShared GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str()); - return NOT_CHANGED; - } - - output_ptr->MutableTensorDesc().SetDataType(DT_FLOAT); - GE_IF_BOOL_EXEC(output_ptr->SetData(reinterpret_cast(buf.get()), data_size) != GRAPH_SUCCESS, - GELOGW("set data failed"); - return NOT_CHANGED); - output_ptr->MutableTensorDesc().SetShape(x_shape); - v_output.push_back(output_ptr); } + if (ret != SUCCESS) { + GELOGW("Rsqrt folding failed."); + return NOT_CHANGED; + } + v_output.push_back(output_ptr); GELOGI("RsqrtKernel success."); return SUCCESS; } diff --git a/src/ge/host_kernels/rsqrt_kernel.h b/src/ge/host_kernels/rsqrt_kernel.h index f0bf9d7e..02b08252 100644 --- a/src/ge/host_kernels/rsqrt_kernel.h +++ b/src/ge/host_kernels/rsqrt_kernel.h @@ -20,12 +20,17 @@ #include #include "inc/kernel.h" +#include "common/fp16_t.h" namespace ge { class RsqrtKernel : public Kernel { public: Status Compute(const ge::OpDescPtr op_desc_ptr, const std::vector &input, std::vector &v_output) override; + + private: + template + Status RsqrtCompute(ConstGeTensorPtr &input_tensor_ptr, GeTensorPtr &output_tensor_ptr); }; } // namespace ge diff --git a/src/ge/hybrid/common/npu_memory_allocator.h b/src/ge/hybrid/common/npu_memory_allocator.h index a9744540..7aa15578 100644 --- a/src/ge/hybrid/common/npu_memory_allocator.h +++ b/src/ge/hybrid/common/npu_memory_allocator.h @@ -23,20 +23,25 @@ #include #include #include "external/ge/ge_api_error_codes.h" +#include "memory/memory_api.h" namespace ge { namespace hybrid { class AllocationAttr { public: + AllocationAttr() = default; explicit AllocationAttr(int padding); explicit AllocationAttr(void *try_reuse_addr); AllocationAttr(int padding, void *try_reuse_addr); ~AllocationAttr() = default; + void SetMemType(MemStorageType memType) { mem_type_ = memType; } + 
MemStorageType GetMemType() { return mem_type_; } private: friend class NpuMemoryAllocator; int padding_ = 0; void *try_reuse_addr_ = nullptr; + MemStorageType mem_type_ = HBM; }; class NpuMemoryAllocator { diff --git a/src/ge/hybrid/executor/hybrid_execution_context.cc b/src/ge/hybrid/executor/hybrid_execution_context.cc index 8144ba52..491220be 100644 --- a/src/ge/hybrid/executor/hybrid_execution_context.cc +++ b/src/ge/hybrid/executor/hybrid_execution_context.cc @@ -17,5 +17,15 @@ #include "hybrid_execution_context.h" namespace ge { -namespace hybrid {} // namespace hybrid +namespace hybrid { +void GraphExecutionContext::SetErrorCode(Status error_code) { + std::lock_guard lk(mu); + this->status = error_code; +} + +Status GraphExecutionContext::GetStatus() const { + std::lock_guard lk(mu); + return this->status; +} +} // namespace hybrid } // namespace ge \ No newline at end of file diff --git a/src/ge/hybrid/executor/hybrid_execution_context.h b/src/ge/hybrid/executor/hybrid_execution_context.h index 96722fa9..37822039 100644 --- a/src/ge/hybrid/executor/hybrid_execution_context.h +++ b/src/ge/hybrid/executor/hybrid_execution_context.h @@ -20,6 +20,7 @@ #include #include #include "common/blocking_queue.h" +#include "common/properties_manager.h" #include "framework/common/debug/ge_log.h" #include "hybrid/common/npu_memory_allocator.h" #include "hybrid/common/tensor_value.h" @@ -32,6 +33,9 @@ namespace ge { namespace hybrid { struct GraphExecutionContext { + void SetErrorCode(Status error_code); + Status GetStatus() const; + uint64_t session_id = 0; const HybridModel *model = nullptr; rtStream_t stream = nullptr; @@ -40,15 +44,18 @@ struct GraphExecutionContext { std::unique_ptr callback_manager; NpuMemoryAllocator *allocator = nullptr; mutable std::unique_ptr profiler; + DumpProperties dump_properties; bool trace_enabled = false; - long profiling_level = 0; bool dump_enabled = false; + long profiling_level = 0; long iteration = 0; + Status status = SUCCESS; + 
mutable std::mutex mu; }; #define RECORD_PROFILING_EVENT(context, evt_type, fmt, category, node_name, ...) \ do { \ - if ((context)->profiler != nullptr) { \ + if ((context != nullptr) && (context)->profiler != nullptr) { \ if (node_name != nullptr) { \ context->profiler->RecordEvent(evt_type, "tid:%lu [%s] [%s] " fmt, GetTid(), node_name, category, \ ##__VA_ARGS__); \ diff --git a/src/ge/hybrid/executor/hybrid_model_executor.cc b/src/ge/hybrid/executor/hybrid_model_executor.cc index d62d7be3..6fe23dee 100644 --- a/src/ge/hybrid/executor/hybrid_model_executor.cc +++ b/src/ge/hybrid/executor/hybrid_model_executor.cc @@ -20,6 +20,10 @@ namespace ge { namespace hybrid { +namespace { +const int kIntBase = 10; +const char *const kEnvProfilingLevel = "HYBRID_PROFILING_LEVEL"; +} // namespace HybridModelExecutor::HybridModelExecutor(HybridModel *model, uint32_t device_id, rtStream_t stream) : model_(model), device_id_(device_id), stream_(stream) {} @@ -43,6 +47,7 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) { GELOGD("Model executed successfully."); if (context_.profiler != nullptr) { + context_.profiler->Dump(std::cout); context_.profiler->Reset(); } @@ -87,6 +92,17 @@ Status HybridModelExecutor::InitExecutionContext() { GE_CHECK_NOTNULL(context_.allocator); context_.callback_manager = std::unique_ptr(new (std::nothrow) CallbackManager(stream_)); GE_CHECK_NOTNULL(context_.callback_manager); + context_.dump_properties = PropertiesManager::Instance().GetDumpProperties(context_.session_id); + const char *profiling_level = std::getenv(kEnvProfilingLevel); + if (profiling_level != nullptr) { + context_.profiling_level = std::strtol(profiling_level, nullptr, kIntBase); + GELOGD("Got profiling level = %d", context_.profiling_level); + if (context_.profiling_level > 0) { + context_.profiler.reset(new (std::nothrow) HybridProfiler()); + GE_CHECK_NOTNULL(context_.profiler); + } + } + if (IsLogEnable(GE_MODULE_NAME, DLOG_DEBUG)) { 
context_.trace_enabled = true; } diff --git a/src/ge/hybrid/executor/hybrid_profiler.cc b/src/ge/hybrid/executor/hybrid_profiler.cc index 4c70e043..0150934e 100644 --- a/src/ge/hybrid/executor/hybrid_profiler.cc +++ b/src/ge/hybrid/executor/hybrid_profiler.cc @@ -57,6 +57,7 @@ void HybridProfiler::Dump(std::ostream &output_stream) { return; } + auto start_dump = std::chrono::system_clock::now(); auto first_evt = events_[0]; auto start = first_evt.timestamp; std::vector prev_timestamps; @@ -70,7 +71,11 @@ void HybridProfiler::Dump(std::ostream &output_stream) { prev_ts = evt.timestamp; output_stream << std::setw(kIndent) << elapsed << "\t\t" << cost << "\t\t" << evt.desc << std::endl; } - + auto end_dump = std::chrono::system_clock::now(); + auto elapsed_dump = std::chrono::duration_cast(end_dump - start).count(); + auto cost_dump = std::chrono::duration_cast(end_dump - start_dump).count(); + output_stream << std::setw(kIndent) << elapsed_dump << "\t\t" << cost_dump << "\t\t" + << "[Dump profiling]" << std::endl; events_.clear(); } diff --git a/src/ge/hybrid/executor/node_done_manager.cc b/src/ge/hybrid/executor/node_done_manager.cc index 3ec45339..de4ea14e 100644 --- a/src/ge/hybrid/executor/node_done_manager.cc +++ b/src/ge/hybrid/executor/node_done_manager.cc @@ -21,7 +21,7 @@ namespace ge { namespace hybrid { namespace { -constexpr int kDefaultWaitTimeoutInSec = 10; +constexpr int kDefaultWaitTimeoutInSec = 60 * 10; } bool NodeDoneManager::Cond::Await() { std::unique_lock lk(cond_mu_); diff --git a/src/ge/hybrid/executor/node_state.cc b/src/ge/hybrid/executor/node_state.cc index c78dd725..e8e94c0d 100644 --- a/src/ge/hybrid/executor/node_state.cc +++ b/src/ge/hybrid/executor/node_state.cc @@ -24,8 +24,10 @@ namespace ge { namespace hybrid { namespace { -constexpr auto kMaxWaitTimeInSec = 10; -} +// 5s * 120, wait for 10m +constexpr auto kWaitInternal = 5; +constexpr auto kMaxWaitTimes = 120; +} // namespace ShapeInferenceState::ShapeInferenceState(const NodeItem 
&node_item) : node_item(node_item) { this->num_pending_shapes_ = node_item.num_inputs - node_item.num_static_input_shapes; GELOGD("[%s] ShapeInferenceState created, pending shape count = %d", node_item.NodeName().c_str(), @@ -72,11 +74,25 @@ Status ShapeInferenceState::AwaitShapesReady(const GraphExecutionContext &contex std::unique_lock lk(mu_); if (num_pending_shapes_ > 0) { GELOGD("[%s] Await pending shape or shape future start.", node_item.NodeName().c_str()); - if (!ready_cv_.wait_for(lk, std::chrono::seconds(kMaxWaitTimeInSec), [&]() { return num_pending_shapes_ == 0; })) { - GELOGE(INTERNAL_ERROR, "[%s] Wait for shape timeout.", node_item.NodeName().c_str()); - return INTERNAL_ERROR; + int try_count = 0; + bool wait_success = false; + while (try_count++ < kMaxWaitTimes) { + if (ready_cv_.wait_for(lk, std::chrono::seconds(kWaitInternal), [&]() { return num_pending_shapes_ == 0; })) { + GELOGD("[%s] Await pending shape or shape future end.", node_item.NodeName().c_str()); + wait_success = true; + break; + } + + if (context.GetStatus() != SUCCESS) { + GELOGE(FAILED, "[%s] Await pending shape cancelled", node_item.NodeName().c_str()); + break; + } + } + + if (!wait_success) { + GELOGE(FAILED, "[%s] Wait for shape timeout.", node_item.NodeName().c_str()); + return FAILED; } - GELOGD("[%s] Await pending shape or shape future end.", node_item.NodeName().c_str()); } for (auto &p : shape_futures) { diff --git a/src/ge/hybrid/executor/subgraph_context.cc b/src/ge/hybrid/executor/subgraph_context.cc index 395e75de..5d94efa2 100644 --- a/src/ge/hybrid/executor/subgraph_context.cc +++ b/src/ge/hybrid/executor/subgraph_context.cc @@ -26,8 +26,8 @@ Status SubgraphContext::Init() { GE_CHECK_NOTNULL(graph_item_); GELOGD("[%s] Start to init subgraph context. 
total inputs = %d, total outputs = %d", graph_item_->GetName().c_str(), graph_item_->TotalInputs(), graph_item_->TotalOutputs()); - all_inputs_.resize(graph_item_->TotalInputs()); - all_outputs_.resize(graph_item_->TotalOutputs()); + all_inputs_.resize(static_cast(graph_item_->TotalInputs())); + all_outputs_.resize(static_cast(graph_item_->TotalOutputs())); return SUCCESS; } @@ -48,7 +48,6 @@ Status SubgraphContext::SetInput(int index, const TensorValue &tensor) { index); return INTERNAL_ERROR; } - all_inputs_[index] = tensor; return SUCCESS; } @@ -60,7 +59,7 @@ Status SubgraphContext::SetInput(const NodeItem &node_item, int input_index, con Status SubgraphContext::SetOutput(const NodeItem &node_item, int output_index, const TensorValue &tensor) { auto index = node_item.output_start + output_index; - if (output_index >= node_item.num_outputs || static_cast(index) >= all_outputs_.size()) { + if ((output_index >= node_item.num_outputs) || (static_cast(index) >= all_outputs_.size())) { GELOGE(INTERNAL_ERROR, "output index output range. all output num = %zu, node_item = %s, output index = %d", all_outputs_.size(), node_item.DebugString().c_str(), output_index); return INTERNAL_ERROR; diff --git a/src/ge/hybrid/executor/subgraph_executor.cc b/src/ge/hybrid/executor/subgraph_executor.cc index 7664e90d..c76bb209 100644 --- a/src/ge/hybrid/executor/subgraph_executor.cc +++ b/src/ge/hybrid/executor/subgraph_executor.cc @@ -106,7 +106,7 @@ Status SubgraphExecutor::InitInputsForKnownShape(const std::vector } auto &input_tensor = inputs[parent_input_index]; - subgraph_context_->SetInput(i, input_tensor); + subgraph_context_->SetInput(static_cast(i), input_tensor); GELOGD("[%s] Set input tensor[%zu] with inputs with index = %d, tensor = %s", graph_item_->GetName().c_str(), i, parent_input_index, input_tensor.DebugString().c_str()); } @@ -175,8 +175,8 @@ Status SubgraphExecutor::PrepareNodes() { GELOGD("[%s] Start to prepare nodes. 
force infer shape = %s.", graph_item_->GetName().c_str(), force_infer_shape_ ? "true" : "false"); auto &all_nodes = graph_item_->GetAllNodes(); - for (size_t i = 0; i < all_nodes.size(); ++i) { - auto &node_item = *all_nodes[i]; + for (auto all_node : all_nodes) { + auto &node_item = *all_node; // for while op if (force_infer_shape_ && !node_item.is_dynamic) { GELOGD("[%s] Force infer shape is set, updating node to dynamic.", node_item.NodeName().c_str()); @@ -295,6 +295,7 @@ Status SubgraphExecutor::ScheduleTasks() { if (ret != SUCCESS) { GELOGE(ret, "[%s] Failed to execute subgraph.", graph_item_->GetName().c_str()); subgraph_context_->OnError(ret); + context_->SetErrorCode(ret); ready_queue_.Stop(); prepare_future.wait(); return ret; @@ -312,9 +313,13 @@ Status SubgraphExecutor::GetOutputs(vector &outputs, std::vectorGetName().c_str()); // copy output data from op to designated position - std::vector output_tensor_desc_list; GE_CHK_STATUS_RET(graph_item_->GetOutputDescList(output_desc), "[%s] Failed to get output tensor desc.", graph_item_->GetName().c_str()); + if (outputs.size() != output_desc.size()) { + GELOGE(INTERNAL_ERROR, "Number of output tensors(%zu) mismatch number of output tensor desc(%zu).", outputs.size(), + output_desc.size()); + return INTERNAL_ERROR; + } return SUCCESS; } diff --git a/src/ge/hybrid/executor/worker/execution_engine.cc b/src/ge/hybrid/executor/worker/execution_engine.cc index 20da6378..b19d0849 100644 --- a/src/ge/hybrid/executor/worker/execution_engine.cc +++ b/src/ge/hybrid/executor/worker/execution_engine.cc @@ -19,6 +19,9 @@ #include "graph/utils/tensor_utils.h" #include "graph/utils/tensor_adapter.h" #include "hybrid/node_executor/node_executor.h" +#include "common/dump/dump_manager.h" +#include "common/dump/dump_op.h" +#include "common/types.h" namespace ge { namespace hybrid { @@ -59,8 +62,10 @@ class NodeDoneCallback { private: Status PrepareConstInputs(const NodeItem &node_item); + Status DumpDynamicNode(); 
GraphExecutionContext *graph_context_; std::shared_ptr context_; + DumpOp dump_op_; }; NodeDoneCallback::NodeDoneCallback(GraphExecutionContext *graph_context, std::shared_ptr task_context) @@ -89,7 +94,7 @@ Status NodeDoneCallback::PrepareConstInputs(const NodeItem &node_item) { return INTERNAL_ERROR; } - vector host_buffer(tensor_size); + vector host_buffer(static_cast(tensor_size)); GELOGD("[%s] To cache output[%d] to host, size = %zu", node_item.NodeName().c_str(), output_idx, output_tensor->GetSize()); GE_CHK_RT_RET( @@ -113,10 +118,78 @@ Status NodeDoneCallback::PrepareConstInputs(const NodeItem &node_item) { return SUCCESS; } +Status NodeDoneCallback::DumpDynamicNode() { + auto node = context_->GetNodeItem().node; + if (node == nullptr) { + GELOGE(PARAM_INVALID, "Get node is nullptr"); + return PARAM_INVALID; + } + auto op_desc = node->GetOpDesc(); + auto stream = context_->GetStream(); + vector input_addrs; + vector output_addrs; + for (int i = 0; i < context_->NumInputs(); i++) { + auto tensor_value = context_->GetInput(i); + GE_CHK_BOOL_RET_STATUS(tensor_value != nullptr, PARAM_INVALID, "Tensor value is nullptr"); + uint64_t input_addr = reinterpret_cast(tensor_value->GetData()); + input_addrs.emplace_back(input_addr); + } + for (int j = 0; j < context_->NumOutputs(); j++) { + auto tensor_value = context_->GetOutput(j); + GE_CHK_BOOL_RET_STATUS(tensor_value != nullptr, PARAM_INVALID, "Tensor value is nullptr"); + uint64_t output_addr = reinterpret_cast(tensor_value->GetData()); + output_addrs.emplace_back(output_addr); + } + + dump_op_.SetDumpInfo(context_->GetDumpProperties(), op_desc, input_addrs, output_addrs, stream); + + GE_CHECK_NOTNULL(graph_context_); + const HybridModel *model = graph_context_->model; + GE_CHECK_NOTNULL(model); + std::string dynamic_model_name = model->GetModelName(); + uint32_t model_id = model->GetModelId(); + dump_op_.SetDynamicModelInfo(dynamic_model_name, model_id); + + void *global_step = nullptr; + TensorValue 
*varible_global_step = context_->GetVariable(NODE_NAME_GLOBAL_STEP); + if (varible_global_step != nullptr) { + global_step = const_cast(varible_global_step->GetData()); + } + + void *loop_per_iter = nullptr; + TensorValue *varible_loop_per_iter = context_->GetVariable(NODE_NAME_FLOWCTRL_LOOP_PER_ITER); + if (varible_loop_per_iter != nullptr) { + loop_per_iter = const_cast(varible_loop_per_iter->GetData()); + } + + void *loop_cond = nullptr; + TensorValue *varible_loop_cond = context_->GetVariable(NODE_NAME_FLOWCTRL_LOOP_COND); + if (varible_loop_cond != nullptr) { + loop_cond = const_cast(varible_loop_cond->GetData()); + } + dump_op_.SetLoopAddr(global_step, loop_per_iter, loop_cond); + + GE_CHK_STATUS_RET(dump_op_.LaunchDumpOp(), "Failed to launch dump op in hybird model"); + + auto rt_ret = rtStreamSynchronize(stream); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(rt_ret, "rtStreamSynchronize failed"); + return rt_ret; + } + return SUCCESS; +} + Status NodeDoneCallback::OnNodeDone() { auto &node_item = context_->GetNodeItem(); GELOGI("[%s] Start callback process.", node_item.NodeName().c_str()); - RECORD_CALLBACK_EVENT(graph_context_, context_->GetNodeName(), "Start"); + RECORD_CALLBACK_EVENT(graph_context_, context_->GetNodeName(), "[Compute] End"); + RECORD_CALLBACK_EVENT(graph_context_, context_->GetNodeName(), "[Callback] Start"); + + auto dump_path = context_->GetDumpProperties().GetDumpPath(); + if (!dump_path.empty()) { + GELOGI("Start to dump dynamic shape,dump_path is %s", dump_path.c_str()); + GE_CHK_STATUS_RET(DumpDynamicNode(), "Failed to dump dynamic node"); + } // release inputs for (int i = 0; i < context_->NumInputs(); ++i) { @@ -196,7 +269,12 @@ Status ExecutionEngine::DoExecuteAsync(NodeState &node_state, TaskContext &task_ GE_CHK_STATUS_RET(ValidateInputTensors(node_state, task_context), "Failed to validate input tensors."); RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[ValidateInputTensors] End"); - 
GE_CHK_STATUS_RET(executor->ExecuteTask(*task, task_context, callback), "[%s] Failed to execute task", + if (context.profiling_level > 0) { + auto *ctx = &context; + const string &name = node_state.GetName(); + task_context.RegisterCallback([ctx, name]() { RECORD_CALLBACK_EVENT(ctx, name.c_str(), "[Compute] Start"); }); + } + GE_CHK_STATUS_RET(node_item.node_executor->ExecuteTask(*task, task_context, callback), "[%s] Failed to execute task", node_state.GetName().c_str()); RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[ExecuteTask] End"); diff --git a/src/ge/hybrid/executor/worker/shape_inference_engine.cc b/src/ge/hybrid/executor/worker/shape_inference_engine.cc index 650bcc54..49a29259 100644 --- a/src/ge/hybrid/executor/worker/shape_inference_engine.cc +++ b/src/ge/hybrid/executor/worker/shape_inference_engine.cc @@ -32,6 +32,11 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) { if (node_item.is_output_shape_static) { return SUCCESS; } + + if (node_item.fused_subgraph != nullptr) { + return InferShapeForSubgraph(node_item, *node_item.fused_subgraph); + } + // Skip shape inference for node of type DEPEND_COMPUTE if (node_item.shape_inference_type == DEPEND_COMPUTE) { GELOGD("[%s] Skipping node with unknown shape type DEPEND_COMPUTE", node_item.NodeName().c_str()); @@ -103,7 +108,7 @@ Status ShapeInferenceEngine::PropagateOutputShapes(const NodeItem &node_item) { node_item.shape_inference_type == DEPEND_SHAPE_RANGE || node_item.shape_inference_type == DEPEND_COMPUTE; GELOGD("[%s] Start to propagate output shapes. 
shape_type = %d", node_item.NodeName().c_str(), node_item.shape_inference_type); - + RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[PropagateOutputShapes] Start"); // propagate each output for (int i = 0; i < node_item.num_outputs; ++i) { auto output_desc = node_item.op_desc->MutableOutputDesc(i); @@ -130,9 +135,77 @@ Status ShapeInferenceEngine::PropagateOutputShapes(const NodeItem &node_item) { } } } - + RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[PropagateOutputShapes] End"); GELOGD("[%s] Propagating output shapes finished successfully.", node_item.NodeName().c_str()); return SUCCESS; } + +Status ShapeInferenceEngine::InferShapeForSubgraph(const NodeItem &node_item, const FusedSubgraph &fused_subgraph) { + GELOGD("[%s] Start to infer shape by fused subgraph", node_item.NodeName().c_str()); + for (auto &it : fused_subgraph.input_mapping) { + auto parent_tensor_desc = node_item.op_desc->MutableInputDesc(it.first); + GE_CHECK_NOTNULL(parent_tensor_desc); + GELOGD("Start to update shape by input[%u]", it.first); + GELOGD("Update shape to [%s]", parent_tensor_desc->GetShape().ToString().c_str()); + GELOGD("Update original shape to [%s]", parent_tensor_desc->GetOriginShape().ToString().c_str()); + for (auto &tensor_desc : it.second) { + tensor_desc->SetShape(parent_tensor_desc->GetShape()); + tensor_desc->SetOriginShape(parent_tensor_desc->GetOriginShape()); + } + } + + for (auto &node : fused_subgraph.nodes) { + GELOGD("[%s] Start to invoke InferShapeAndType", node->GetName().c_str()); + GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndType(node)); + GELOGD("[%s] Done invoking InferShapeAndType", node->GetName().c_str()); + GE_CHK_STATUS_RET(UpdatePeerNodeShape(*node), "[%s] Failed to update shapes of peer node.", + node->GetName().c_str()); + } + + for (auto &it : fused_subgraph.output_mapping) { + uint32_t parent_output_idx = it.first; + const auto &op_desc = it.second; + GELOGD("Update parent 
output[%d] by [%s]", parent_output_idx, op_desc->GetName().c_str()); + auto input_desc = op_desc->MutableInputDesc(0); + GE_CHECK_NOTNULL(input_desc); + auto parent_output_tensor_desc = node_item.op_desc->MutableOutputDesc(parent_output_idx); + GE_CHECK_NOTNULL(parent_output_tensor_desc); + GELOGD("Update shape to [%s]", input_desc->GetShape().ToString().c_str()); + GELOGD("Update original shape to [%s]", input_desc->GetOriginShape().ToString().c_str()); + parent_output_tensor_desc->SetOriginShape(input_desc->GetOriginShape()); + parent_output_tensor_desc->SetShape(input_desc->GetShape()); + } + + GELOGD("[%s] Done shape inference by subgraph successfully.", node_item.NodeName().c_str()); + return SUCCESS; +} + +Status ShapeInferenceEngine::UpdatePeerNodeShape(const Node &node) { + auto op_desc = node.GetOpDesc(); + for (const auto &out_anchor : node.GetAllOutDataAnchors()) { + auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx()); + for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) { + auto peer_node = peer_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(peer_node); + auto peer_op_desc = peer_node->GetOpDesc(); + GE_CHECK_NOTNULL(peer_op_desc); + auto peer_input_desc = peer_op_desc->MutableInputDesc(peer_anchor->GetIdx()); + if (peer_input_desc == nullptr) { + GELOGE(GRAPH_FAILED, "peer_input_desc is nullptr"); + continue; + } + + GELOGI("Peer input op desc name is %s, need to flush: shape size is %zu, datatype is %d, original datatype is %d", + peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor->GetShape().GetDimNum(), + output_tensor->GetDataType(), output_tensor->GetOriginDataType()); + peer_input_desc->SetOriginShape(output_tensor->GetOriginShape()); + peer_input_desc->SetShape(output_tensor->GetShape()); + GELOGI("Peer input op desc name is %s, shape size is %zu, datatype is %d, original datatype is %d", + peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), peer_input_desc->GetShape().GetDimNum(), + 
peer_input_desc->GetDataType(), peer_input_desc->GetOriginDataType()); + } + } + return SUCCESS; +} } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/executor/worker/shape_inference_engine.h b/src/ge/hybrid/executor/worker/shape_inference_engine.h index 65878818..f8a391e2 100644 --- a/src/ge/hybrid/executor/worker/shape_inference_engine.h +++ b/src/ge/hybrid/executor/worker/shape_inference_engine.h @@ -30,9 +30,12 @@ class ShapeInferenceEngine { Status InferShape(NodeState &node_state); + Status InferShapeForSubgraph(const NodeItem &node_item, const FusedSubgraph &fused_subgraph); + Status PropagateOutputShapes(const NodeItem &node_item); private: + static Status UpdatePeerNodeShape(const Node &node); Status AwaitDependentNodes(NodeState &node_state); GraphExecutionContext *execution_context_; diff --git a/src/ge/hybrid/executor/worker/task_compile_engine.cc b/src/ge/hybrid/executor/worker/task_compile_engine.cc index 57b19f5f..e2e94f66 100644 --- a/src/ge/hybrid/executor/worker/task_compile_engine.cc +++ b/src/ge/hybrid/executor/worker/task_compile_engine.cc @@ -22,12 +22,13 @@ namespace ge { namespace hybrid { Status TaskCompileEngine::Compile(NodeState &node_state, GraphExecutionContext *context) { const auto &node_item = *node_state.GetNodeItem(); - RECORD_COMPILE_EVENT(context, node_item.NodeName().c_str(), "Start"); + GE_CHECK_NOTNULL(context); + RECORD_COMPILE_EVENT(context, node_item.NodeName().c_str(), "[Compile] Start"); GE_CHK_RT_RET(rtCtxSetCurrent(context->rt_gen_context)); shared_ptr kernel_task; auto ret = node_item.node_executor->CompileTask(*context->model, node_item.node, kernel_task); - RECORD_COMPILE_EVENT(context, node_state.GetName().c_str(), "End"); + RECORD_COMPILE_EVENT(context, node_state.GetName().c_str(), "[Compile] End"); GE_CHK_STATUS_RET(ret, "Failed to create task for node: %s", node_item.NodeName().c_str()); node_state.SetKernelTask(kernel_task); GELOGI("Compiling node %s successfully", 
node_state.GetName().c_str()); diff --git a/src/ge/hybrid/model/hybrid_model.cc b/src/ge/hybrid/model/hybrid_model.cc index 0cb81aa3..18db28cb 100644 --- a/src/ge/hybrid/model/hybrid_model.cc +++ b/src/ge/hybrid/model/hybrid_model.cc @@ -118,5 +118,7 @@ const GraphItem *HybridModel::GetSubgraphItem(const ComputeGraphPtr &subgraph) c auto subgraph_name = subgraph->GetName(); return GetSubgraphItem(subgraph_name); } + +const string &HybridModel::GetModelName() const { return model_name_; } } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/model/hybrid_model.h b/src/ge/hybrid/model/hybrid_model.h index f554752e..668b5fd7 100644 --- a/src/ge/hybrid/model/hybrid_model.h +++ b/src/ge/hybrid/model/hybrid_model.h @@ -69,6 +69,8 @@ class HybridModel { const GraphItem *GetSubgraphItem(const ComputeGraphPtr &subgraph) const; + const string &GetModelName() const; + private: friend class HybridModelBuilder; friend class HybridModelAsyncExecutor; diff --git a/src/ge/hybrid/model/hybrid_model_builder.cc b/src/ge/hybrid/model/hybrid_model_builder.cc index 436beada..97783711 100644 --- a/src/ge/hybrid/model/hybrid_model_builder.cc +++ b/src/ge/hybrid/model/hybrid_model_builder.cc @@ -16,6 +16,7 @@ #include "hybrid/model/hybrid_model_builder.h" #include "common/math/math_util.h" +#include "graph/ge_context.h" #include "graph/utils/node_utils.h" #include "graph/debug/ge_attr_define.h" #include "graph/load/new_model_manager/model_utils.h" @@ -184,10 +185,8 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s GELOGD("[%s] Input[%d] do not have peer anchor", node_item.NodeName().c_str(), in_anchor->GetIdx()); continue; } - auto src_node = peer_anchor->GetOwnerNode(); GE_CHECK_NOTNULL(src_node); - auto src_node_item = MutableNodeItem(src_node); GE_CHECK_NOTNULL(src_node_item); @@ -207,6 +206,22 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s } } + // cond or branch need to be prepared before the execution 
of IF or CASE + if (node_item.node_type == IF || node_item.node_type == CASE) { + const auto &in_anchor = ge_node->GetInDataAnchor(0); + GE_CHECK_NOTNULL(in_anchor); + const auto &peer_anchor = in_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_anchor); + auto src_node = peer_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(src_node); + auto src_node_item = MutableNodeItem(src_node); + GE_CHECK_NOTNULL(src_node_item); + src_node_item->has_observer = true; + node_item.dependents_for_execution.emplace_back(src_node); + GELOGD("[%s] Dependent added from %s for control op's cond/branch", node_item.NodeName().c_str(), + src_node_item->NodeName().c_str()); + } + for (const auto &input_name : dependencies) { int input_index = node_item.op_desc->GetInputIndexByName(input_name); if (input_index < 0) { @@ -416,18 +431,12 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap GE_CHECK_NOTNULL(op_desc); const auto &op_type = node->GetType(); - if (op_type == DATA || op_type == AIPP_DATA_TYPE || op_type == NETOUTPUT) { + if (op_type != PARTITIONEDCALL) { merged_graph->AddNode(node); GELOGD("[%s] Node added to merged graph.", op_desc->GetName().c_str()); continue; } - if (op_type != PARTITIONEDCALL) { - GELOGE(INTERNAL_ERROR, "[%s] Unexpected node in root graph. 
type = %s", op_desc->GetName().c_str(), - op_type.c_str()); - return INTERNAL_ERROR; - } - bool is_unknown_shape = false; GE_CHK_GRAPH_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown_shape), "Failed to invoke GetNodeUnknownShapeStatus."); @@ -538,14 +547,16 @@ Status HybridModelBuilder::BuildOutputMapping(GraphItem &graph_item, const NodeI Status HybridModelBuilder::LoadGraph() { auto root_graph = ge_root_model_->GetRootGraph(); - std::shared_ptr merged_graph; - GELOGI("Before merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", root_graph->GetDirectNodesSize(), - root_graph->GetAllNodesSize()); - GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(*root_graph, merged_graph), "Failed to unfold subgraphs."); - root_graph = std::move(merged_graph); - GELOGI("After merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", root_graph->GetDirectNodesSize(), - root_graph->GetAllNodesSize()); - GE_DUMP(root_graph, "hybrid_merged_graph"); + if (!GetContext().GetHostExecFlag()) { + std::shared_ptr merged_graph; + GELOGI("Before merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", root_graph->GetDirectNodesSize(), + root_graph->GetAllNodesSize()); + GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(*root_graph, merged_graph), "Failed to unfold subgraphs."); + root_graph = std::move(merged_graph); + GELOGI("After merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", root_graph->GetDirectNodesSize(), + root_graph->GetAllNodesSize()); + GE_DUMP(root_graph, "hybrid_merged_graph"); + } GE_CHK_STATUS_RET(LoadDynamicSubgraph(*root_graph, true), "Failed to load root graph."); GELOGD("Done loading root graph successfully."); @@ -641,6 +652,10 @@ Status HybridModelBuilder::HandleDtString(const GeTensor &tensor, void *var_addr } Status HybridModelBuilder::AssignUninitializedConstantOps() { + if (GetContext().GetHostExecFlag()) { + GELOGI("no need to assign when exec on host."); + return SUCCESS; + } for (auto &it : hybrid_model_.constant_op_nodes_) 
{ const string &var_name = it.first; const NodePtr &var_node = it.second; @@ -820,6 +835,7 @@ Status HybridModelBuilder::IndexSpecialNodes() { GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node->GetOpDesc()); auto op_type = node->GetType(); + GELOGD("node name = %s, node type = %s", node->GetName().c_str(), node->GetType().c_str()); if (op_type == VARIABLE) { hybrid_model_.variable_nodes_.emplace(node->GetName(), node); } else if (op_type == CONSTANTOP) { @@ -937,12 +953,19 @@ Status HybridModelBuilder::InitRuntimeParams() { // session id and var size is same for every model auto first_model = ge_root_model_->GetSubgraphInstanceNameToModel().begin()->second; ret = ge::AttrUtils::GetInt(first_model, ge::MODEL_ATTR_SESSION_ID, value); - runtime_param_.session_id = ret ? (uint64_t)value : 0; + runtime_param_.session_id = ret ? static_cast(value) : 0; ret = ge::AttrUtils::GetInt(first_model, ATTR_MODEL_TASK_GEN_VAR_ADDR, value); - runtime_param_.logic_var_base = ret ? (uint64_t)value : 0; - ret = ge::AttrUtils::GetInt(first_model, ATTR_MODEL_VAR_SIZE, value); - runtime_param_.var_size = ret ? (uint64_t)value : 0; + runtime_param_.logic_var_base = ret ? static_cast(value) : 0; runtime_param_.graph_id = ge_root_model_->GetRootGraph()->GetGraphID(); + value = 0; + for (auto &it : ge_root_model_->GetSubgraphInstanceNameToModel()) { + (void)ge::AttrUtils::GetInt(it.second, ATTR_MODEL_VAR_SIZE, value); + if (value > 0) { + runtime_param_.var_size = static_cast(value); + break; + } + } + GELOGI("InitRuntimeParams(), session_id:%lu, var_size:%lu. 
graph_id = %u", runtime_param_.session_id, runtime_param_.var_size, runtime_param_.graph_id); @@ -1178,6 +1201,8 @@ Status HybridModelBuilder::LoadDynamicSubgraph(ComputeGraph &graph, bool is_root } graph_item->node_items_.emplace_back(node_item); + // parse var outputs + GE_CHK_STATUS_RET_NOLOG(ParseVarOutputs(*node_item)); GELOGD("NodeItem created: %s", node_item->DebugString().c_str()); } @@ -1197,6 +1222,20 @@ Status HybridModelBuilder::LoadDynamicSubgraph(ComputeGraph &graph, bool is_root return SUCCESS; } +Status HybridModelBuilder::ParseVarOutputs(NodeItem &node_item) { + for (int i = 0; i < node_item.num_outputs; ++i) { + auto output_tensor_desc = node_item.op_desc->GetOutputDesc(i); + std::string var_name; + (void)AttrUtils::GetStr(output_tensor_desc, ASSIGN_VAR_NAME, var_name); + if (!var_name.empty()) { + auto var_node = hybrid_model_.GetVariableNode(var_name); + GE_CHECK_NOTNULL(var_node); + node_item.ref_outputs.emplace(i, var_node); + } + } + return SUCCESS; +} + Status HybridModelBuilder::BuildInputMapping(GraphItem &graph_item, vector &data_nodes, bool is_root_graph) { uint32_t data_op_index = 0; diff --git a/src/ge/hybrid/model/hybrid_model_builder.h b/src/ge/hybrid/model/hybrid_model_builder.h index 1103aa1c..ecd327ff 100644 --- a/src/ge/hybrid/model/hybrid_model_builder.h +++ b/src/ge/hybrid/model/hybrid_model_builder.h @@ -71,6 +71,7 @@ class HybridModelBuilder { Status InitConstantOps(); Status InitVariableTensors(); Status LoadDynamicSubgraph(ComputeGraph &graph, bool is_root_graph); + Status ParseVarOutputs(NodeItem &node_item); Status LoadKnownShapedSubgraph(ComputeGraph &graph, NodeItem *parent_node_item); const char *GetGraphName() const { return hybrid_model_.model_name_.c_str(); } diff --git a/src/ge/hybrid/model/node_item.cc b/src/ge/hybrid/model/node_item.cc index bfc29c84..7ec8d946 100644 --- a/src/ge/hybrid/model/node_item.cc +++ b/src/ge/hybrid/model/node_item.cc @@ -17,12 +17,85 @@ #include "node_item.h" #include #include 
"common/debug/log.h" +#include "graph/common/omg_util.h" +#include "graph/compute_graph.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/node_utils.h" #include "hybrid/node_executor/node_executor.h" namespace ge { namespace hybrid { +namespace { +const char *const kAttrNameOriginalFusionGraph = "_original_fusion_graph"; +const char *const kNodeTypeRetVal = "_RetVal"; + +Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) { + uint32_t parent_index = 0; + if (!AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + GELOGE(FAILED, "[%s] Failed to get attr [%s]", op_desc.GetName().c_str(), ATTR_NAME_PARENT_NODE_INDEX.c_str()); + return FAILED; + } + + for (auto &node_and_anchor : node.GetOutDataNodesAndAnchors()) { + auto dst_op_desc = node_and_anchor.first->GetOpDesc(); + GE_CHECK_NOTNULL(dst_op_desc); + auto in_idx = node_and_anchor.second->GetIdx(); + auto tensor_desc = dst_op_desc->MutableInputDesc(in_idx); + fused_subgraph.input_mapping[parent_index].emplace_back(tensor_desc); + GELOGD("Input[%u] mapped to [%s:%u]", parent_index, dst_op_desc->GetName().c_str(), in_idx); + } + + return SUCCESS; +} + +Status ParseOutputMapping(OpDescPtr op_desc, FusedSubgraph &fused_subgraph) { + uint32_t parent_index = 0; + if (!AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + GELOGE(FAILED, "[%s] Failed to get attr [%s]", op_desc->GetName().c_str(), ATTR_NAME_PARENT_NODE_INDEX.c_str()); + return FAILED; + } + + fused_subgraph.output_mapping.emplace(parent_index, op_desc); + return SUCCESS; +} + +Status ParseFusedSubgraph(NodeItem &node_item) { + if (!node_item.op_desc->HasAttr(kAttrNameOriginalFusionGraph)) { + return SUCCESS; + } + + GELOGI("[%s] Start to parse fused subgraph.", node_item.node_name.c_str()); + auto fused_subgraph = std::unique_ptr(new (std::nothrow) FusedSubgraph()); + GE_CHECK_NOTNULL(fused_subgraph); + + ComputeGraphPtr fused_graph; + 
(void)AttrUtils::GetGraph(*node_item.op_desc, kAttrNameOriginalFusionGraph, fused_graph); + GE_CHECK_NOTNULL(fused_graph); + + fused_graph->SetGraphUnknownFlag(true); + fused_subgraph->graph = fused_graph; + GE_CHK_GRAPH_STATUS_RET(fused_graph->TopologicalSorting()); + + for (auto &node : fused_graph->GetAllNodes()) { + GE_CHECK_NOTNULL(node); + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + std::string node_type; + GE_CHK_STATUS_RET(GetOriginalType(node, node_type)); + if (node_type == DATA) { + GE_CHK_GRAPH_STATUS_RET(ParseInputMapping(*node, *op_desc, *fused_subgraph)); + } else if (node_type == kNodeTypeRetVal) { + GE_CHK_GRAPH_STATUS_RET(ParseOutputMapping(op_desc, *fused_subgraph)); + } else { + fused_subgraph->nodes.emplace_back(node); + } + } + + node_item.fused_subgraph = std::move(fused_subgraph); + GELOGI("[%s] Done parsing fused subgraph successfully.", node_item.NodeName().c_str()); + return SUCCESS; +} +} // namespace NodeItem::NodeItem(NodePtr node) : node(std::move(node)) { this->op_desc = this->node->GetOpDesc().get(); this->node_id = this->op_desc->GetId(); @@ -39,7 +112,7 @@ Status NodeItem::Init() { GE_CHK_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*node, is_dynamic), "[%s] Failed to get shape status.", node->GetName().c_str()); - + GE_CHK_STATUS_RET(ParseFusedSubgraph(*this), "[%s] Failed to parse fused subgraph", node_name.c_str()); if (is_dynamic) { for (int i = 0; i < num_inputs; ++i) { const auto &input_desc = op_desc->MutableInputDesc(i); diff --git a/src/ge/hybrid/model/node_item.h b/src/ge/hybrid/model/node_item.h index ff024b36..53cdeca6 100644 --- a/src/ge/hybrid/model/node_item.h +++ b/src/ge/hybrid/model/node_item.h @@ -29,6 +29,13 @@ namespace hybrid { class NodeTask; class NodeExecutor; +struct FusedSubgraph { + std::map> input_mapping; + std::map output_mapping; + std::vector nodes; + ComputeGraphPtr graph; +}; + // for caching static information across execution struct NodeItem { explicit NodeItem(NodePtr 
node); @@ -42,8 +49,6 @@ struct NodeItem { bool IsControlOp() const; - bool NeedInfershape() const; - void SetToDynamic(); std::string DebugString() const; @@ -70,6 +75,7 @@ struct NodeItem { vector>> outputs; std::shared_ptr kernel_task; + std::unique_ptr fused_subgraph; const NodeExecutor *node_executor = nullptr; std::map ref_outputs; std::map reuse_inputs; diff --git a/src/ge/hybrid/node_executor/aicore/aicore_node_executor.cc b/src/ge/hybrid/node_executor/aicore/aicore_node_executor.cc index 71280649..942d6d9e 100644 --- a/src/ge/hybrid/node_executor/aicore/aicore_node_executor.cc +++ b/src/ge/hybrid/node_executor/aicore/aicore_node_executor.cc @@ -18,6 +18,7 @@ #include "cce/taskdown_common.hpp" #include "hybrid/executor/hybrid_execution_context.h" #include "init/gelib.h" +#include "hybrid/executor/hybrid_execution_context.h" namespace ge { namespace hybrid { @@ -153,18 +154,25 @@ Status AiCoreNodeExecutor::CompileTask(const HybridModel &model, const NodePtr & } Status AiCoreNodeTask::ExecuteAsync(TaskContext &context, std::function done_callback) { + RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeTaskExecuteAsync] Start"); auto op_desc = context.GetNodeItem().op_desc; GE_CHECK_NOTNULL(op_desc); GELOGI("[%s] ExecuteAsync Start.", op_desc->GetName().c_str()); for (auto &task : tasks_) { + RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeLaunchKernel] Start"); GE_CHK_STATUS_RET_NOLOG(task->LaunchKernel(context.GetStream())); + RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeLaunchKernel] End"); + RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeLaunchKernel] End"); } if (done_callback != nullptr) { + RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeRegisterCallback] Start"); GE_CHK_STATUS_RET_NOLOG(context.RegisterCallback(done_callback)); + 
RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeRegisterCallback] End"); } GELOGD("[%s] ExecuteAsync End.", op_desc->GetName().c_str()); + RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeTaskExecuteAsync] End"); return SUCCESS; } diff --git a/src/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc b/src/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc index 46d9a0aa..44fe377a 100644 --- a/src/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc +++ b/src/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc @@ -165,6 +165,7 @@ Status AicpuNodeTaskBase::UpdateArgs(TaskContext &context) { } Status AicpuNodeTaskBase::ExecuteAsync(TaskContext &context, std::function done_callback) { + RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AicpuNodeTaskBaseExecuteAsync] Start"); GELOGI("Node[%s] execute async start. unknown_type=%d.", node_name_.c_str(), unknown_type_); GE_CHK_STATUS_RET(LaunchTask(context)); @@ -187,6 +188,7 @@ Status AicpuNodeTaskBase::ExecuteAsync(TaskContext &context, std::functionGetData(), kernel_buf_->GetSize(), flag, context.GetStream())); + RECORD_EXECUTION_EVENT(context.GetExecutionContext(), node_name_.c_str(), "[AicpuTfNodertKernelLaunchEx] End"); GELOGI("Node[%s] launch end.", node_name_.c_str()); return SUCCESS; } @@ -704,7 +708,10 @@ Status AicpuNodeTask::TaskCallback(TaskContext &context) { Status AiCpuNodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const { // malloc HBM memory at Init, here just update them - return task.UpdateArgs(context); + RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCpuNodeExecutorPrepareTask] Start"); + Status status = task.UpdateArgs(context); + RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCpuNodeExecutorPrepareTask] End"); + return status; } Status AiCpuNodeExecutor::LoadTask(const HybridModel &model, const NodePtr 
&node, diff --git a/src/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc b/src/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc index 2e1893f2..122af0f5 100644 --- a/src/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc +++ b/src/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc @@ -23,12 +23,14 @@ #include "graph/debug/ge_attr_define.h" #include "graph/load/new_model_manager/model_utils.h" #include "graph/load/new_model_manager/model_manager.h" +#include "hybrid/executor/hybrid_execution_context.h" namespace ge { namespace hybrid { REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::COMPILED_SUBGRAPH, KnownNodeExecutor); Status KnownNodeTask::ExecuteAsync(TaskContext &context, std::function done_callback) { + RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeTaskExecuteAsync] Start"); GELOGI("[%s] KnownNodeTask::ExecuteAsync in.", context.GetNodeName()); if (davinci_model_->GetTaskList().size() == 0) { GELOGW("KnownNodeExecutor::ExecuteAsync davinci moel has no taskinfo."); @@ -43,18 +45,23 @@ Status KnownNodeTask::ExecuteAsync(TaskContext &context, std::function d } } - context.RegisterCallback(done_callback); + GE_CHK_STATUS_RET_NOLOG(context.RegisterCallback(done_callback)); return SUCCESS; } rtError_t rt_ret; GELOGI("rtModelExecute start."); + RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodertModelExecute] Start"); rt_ret = rtModelExecute(davinci_model_->GetRtModelHandle(), context.GetStream(), 0); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtModelExecute error, ret: Ox%X", rt_ret); return FAILED;); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, + GELOGE(rt_ret, "rtModelExecute error, ret: hybrid_model_executorOx%X", rt_ret); + return FAILED;); + RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodertModelExecute] End"); GELOGI("rtModelExecute end"); - 
context.RegisterCallback(done_callback); + GE_CHK_STATUS_RET_NOLOG(context.RegisterCallback(done_callback)); GELOGI("[%s] KnownNodeTask::ExecuteAsync success.", context.GetNodeName()); + RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeTaskExecuteAsync] End"); return SUCCESS; } @@ -99,9 +106,13 @@ Status KnownNodeTask::Init(TaskContext &context) { // allocate mem base void *buffer = nullptr; if (davinci_model_->TotalMemSize() != 0) { + RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), + "[KnownNodeTask_AllocateWorkspace] Start"); GE_CHK_STATUS_RET( context.AllocateWorkspace(davinci_model_->TotalMemSize(), &buffer, davinci_model_->GetRuntimeParam().mem_base), "known node task allocate workspace failed."); + RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), + "[KnownNodeTask_AllocateWorkspace] End, size %zu", davinci_model_->TotalMemSize()); bool addr_not_changed = false; if (davinci_model_->GetRuntimeParam().mem_base == buffer) { addr_not_changed = true; @@ -126,9 +137,15 @@ Status KnownNodeTask::Init(TaskContext &context) { Status KnownNodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const { GELOGI("[%s] KnownNodeExecutor::PrepareTask in.", context.GetNodeName()); + RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeExecutorPrepareTask] Start"); + RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeExecutorTaskInit] Start"); GE_CHK_STATUS_RET(task.Init(context), "known node init davinci model failed."); + RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeExecutorTaskInit] End"); + RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeExecutorUpdateArgs] Start"); GE_CHK_STATUS_RET(task.UpdateArgs(context), "known node task update args failed."); + RECORD_EXECUTION_EVENT(context.GetExecutionContext(), 
context.GetNodeName(), "[KnownNodeExecutorUpdateArgs] End"); + RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeExecutorPrepareTask] End"); GELOGI("[%s] KnownNodeExecutor::PrepareTask success.", context.GetNodeName()); return SUCCESS; } @@ -145,8 +162,9 @@ Status KnownNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node // set known node flag as true davinci_model->SetKnownNode(true); - // set model id - davinci_model->SetId(model.GetModelId()); + // set model id as root node's node id + davinci_model->SetId(node->GetOpDesc()->GetId()); + GELOGD("KnownNodeExecutor::LoadTask node id %u.", node->GetOpDesc()->GetId()); GE_CHK_STATUS_RET(davinci_model->Assign(ge_model), "KnownNodeExecutor::LoadTask davincimodel assign failed."); @@ -158,8 +176,10 @@ Status KnownNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node Status KnownNodeExecutor::ExecuteTask(NodeTask &task, TaskContext &context, const std::function &callback) const { + RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeExecutorExecuteTask] Start"); GE_CHK_STATUS_RET(task.ExecuteAsync(context, callback), "Failed to execute task. 
node = %s", context.GetNodeItem().NodeName().c_str()); + RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeExecutorExecuteTask] End"); return SUCCESS; } } // namespace hybrid diff --git a/src/ge/hybrid/node_executor/controlop/control_op_executor.cc b/src/ge/hybrid/node_executor/controlop/control_op_executor.cc index aee7fb77..2bf7407c 100644 --- a/src/ge/hybrid/node_executor/controlop/control_op_executor.cc +++ b/src/ge/hybrid/node_executor/controlop/control_op_executor.cc @@ -16,12 +16,21 @@ #include "control_op_executor.h" #include "graph/utils/node_utils.h" +#include "graph/utils/type_utils.h" #include "hybrid/executor/hybrid_execution_context.h" #include "hybrid/executor/subgraph_executor.h" namespace ge { namespace hybrid { REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::CONTROL_OP, ControlOpNodeExecutor); +namespace { +template +Status CopyScalarValueToHost(const TensorValue &tensor, T &value) { + GE_CHECK_GE(tensor.GetSize(), sizeof(value)); + GE_CHK_RT_RET(rtMemcpy(&value, sizeof(value), tensor.GetData(), sizeof(value), RT_MEMCPY_DEVICE_TO_HOST)); + return SUCCESS; +} +} // namespace Status ControlOpNodeTask::ExecuteSubgraph(const GraphItem *subgraph, TaskContext &task_context, const std::function &done_callback) { @@ -45,10 +54,32 @@ Status ControlOpNodeTask::ExecuteSubgraph(const GraphItem *subgraph, TaskContext return SUCCESS; } -Status ControlOpNodeTask::CopyTensorValueToHost(const TensorValue &tensor, int32_t &value) { - GE_CHECK_NOTNULL(tensor.GetData()); - GE_CHECK_GE(tensor.GetSize(), sizeof(value)); - GE_CHK_RT_RET(rtMemcpy(&value, sizeof(value), tensor.GetData(), sizeof(value), RT_MEMCPY_DEVICE_TO_HOST)); +Status ControlOpNodeTask::ToBool(const TensorValue &tensor, DataType data_type, bool &value) { + switch (data_type) { +#define CASE(DT, T) \ + case (DT): { \ + T val{}; \ + GE_CHK_STATUS_RET(CopyScalarValueToHost(tensor, val)); \ + value = val != 0; \ + break; \ + } + // DT_STRING was 
handled in CondPass + CASE(DT_FLOAT, float) + CASE(DT_DOUBLE, double) + CASE(DT_INT32, int32_t) + CASE(DT_UINT8, uint8_t) + CASE(DT_INT16, int16_t) + CASE(DT_INT8, int8_t) + CASE(DT_INT64, int64_t) +#undef CASE + case DT_BOOL: + GE_CHK_STATUS_RET(CopyScalarValueToHost(tensor, value)); + break; + default: + GELOGE(UNSUPPORTED, "Data type %s is not support by cond.", TypeUtils::DataTypeToSerialString(data_type).c_str()); + return UNSUPPORTED; + } + return SUCCESS; } @@ -60,11 +91,6 @@ Status ControlOpNodeTask::UpdateArgs(TaskContext &context) { Status ControlOpNodeTask::ExecuteAsync(TaskContext &task_context, std::function done_callback) { auto ret = DoExecuteAsync(task_context, done_callback); task_context.SetStatus(ret); - - if (done_callback) { - done_callback(); - } - return ret; } @@ -86,16 +112,24 @@ Status IfOpNodeTask::Init(const NodePtr &node, const HybridModel &model) { return SUCCESS; } -const GraphItem *IfOpNodeTask::SelectBranch(int32_t cond) const { return cond != 0 ? then_ : else_; } - Status IfOpNodeTask::DoExecuteAsync(TaskContext &task_context, const std::function &done_callback) const { - auto cond_tensor = task_context.GetInput(kIfCondIndex); - GE_CHECK_NOTNULL(cond_tensor); - int32_t cond_val = 0; - GE_CHK_STATUS_RET(CopyTensorValueToHost(*cond_tensor, cond_val), "[%s] Failed to get cond value.", - task_context.GetNodeName()); + auto cond_tensor_desc = task_context.MutableInputDesc(kIfCondIndex); + auto data_type = cond_tensor_desc->GetDataType(); + const auto &shape = cond_tensor_desc->MutableShape(); + bool cond_val = false; + if (shape.IsScalar()) { + auto cond_tensor = task_context.GetInput(kIfCondIndex); + GE_CHECK_NOTNULL(cond_tensor); + GE_CHK_STATUS_RET(ToBool(*cond_tensor, data_type, cond_val), "[%s] Failed to get cond value.", + task_context.GetNodeName()); + } else { + // true if num elements is non-zero + cond_val = shape.GetShapeSize() != 0; + GELOGD("[%s] Cond tensor shape = [%s], cond value = %d", task_context.GetNodeName(), 
shape.ToString().c_str(), + cond_val); + } - auto subgraph = SelectBranch(cond_val); + auto subgraph = cond_val ? then_ : else_; GELOGD("[%s] Taking subgraph [%s] by cond = [%d]", task_context.GetNodeName(), subgraph->GetName().c_str(), cond_val); GE_CHK_STATUS_RET(ExecuteSubgraph(subgraph, task_context, done_callback), "[%s] Failed to execute subgraph. cond = %d", task_context.GetNodeName(), cond_val); @@ -139,9 +173,7 @@ Status CaseOpNodeTask::DoExecuteAsync(TaskContext &task_context, const std::func auto branch_tensor = task_context.GetInput(kCaseBranchIndex); GE_CHECK_NOTNULL(branch_tensor); int32_t branch_index = 0; - GE_CHK_STATUS_RET(CopyTensorValueToHost(*branch_tensor, branch_index), "[%s] Failed to get branch index.", - task_context.GetNodeName()); - + GE_CHK_STATUS_RET(CopyScalarValueToHost(*branch_tensor, branch_index)); const GraphItem *subgraph = SelectBranch(branch_index); GELOGI("[%s] Taking subgraph [%s] by branch = [%d]", task_context.GetNodeName(), subgraph->GetName().c_str(), branch_index); @@ -201,6 +233,9 @@ Status WhileOpNodeTask::DoExecuteAsync(TaskContext &task_context, const std::fun *output_tensor_desc = *input_tensor_desc; } + if (done_callback) { + done_callback(); + } return SUCCESS; } @@ -236,6 +271,9 @@ Status WhileOpNodeTask::DoExecuteAsync(TaskContext &task_context, const std::fun GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(i, *input_tensor)); } + if (done_callback) { + done_callback(); + } return SUCCESS; } @@ -261,16 +299,28 @@ Status WhileOpNodeTask::ExecuteCond(TaskContext &task_context, bool &is_continue // get cond output GE_CHK_STATUS_RET(executor->Synchronize(), "[%s] Failed to sync cond-subgraph result.", cond_->GetName().c_str()); std::vector cond_outputs; - GE_CHK_STATUS_RET(executor->GetOutputs(cond_outputs), "[%s] Failed to get cond-output.", cond_->GetName().c_str()); - if (cond_outputs.empty()) { - GELOGE(INTERNAL_ERROR, "[%s] Cond output is empty.", task_context.GetNodeName()); + std::vector 
cond_output_desc_list; + GE_CHK_STATUS_RET(executor->GetOutputs(cond_outputs, cond_output_desc_list), "[%s] Failed to get cond-output.", + cond_->GetName().c_str()); + if (cond_outputs.size() != kCondOutputSize || cond_output_desc_list.size() != kCondOutputSize) { + GELOGE(INTERNAL_ERROR, "[%s] Number of cond outputs is invalid. number = %zu", task_context.GetNodeName(), + cond_outputs.size()); return INTERNAL_ERROR; } - int cond_val = 0; - GE_CHK_STATUS_RET(CopyTensorValueToHost(cond_outputs[0], cond_val), "[%s] Failed to get cond result.", - task_context.GetNodeName()); - is_continue = cond_val != 0; + auto &cond_tensor_desc = cond_output_desc_list[0]; + const auto &shape = cond_tensor_desc->GetShape(); + if (shape.IsScalar()) { + auto data_type = cond_tensor_desc->GetDataType(); + GE_CHK_STATUS_RET(ToBool(cond_outputs[0], data_type, is_continue), "[%s] Failed to get cond value.", + task_context.GetNodeName()); + } else { + // true if num elements is non-zero + is_continue = shape.GetShapeSize() > 0; + GELOGD("[%s] Cond tensor shape = [%s], is_continue = %d", task_context.GetNodeName(), shape.ToString().c_str(), + is_continue); + } + return SUCCESS; } diff --git a/src/ge/hybrid/node_executor/controlop/control_op_executor.h b/src/ge/hybrid/node_executor/controlop/control_op_executor.h index 0619c6a0..68db7e91 100644 --- a/src/ge/hybrid/node_executor/controlop/control_op_executor.h +++ b/src/ge/hybrid/node_executor/controlop/control_op_executor.h @@ -32,7 +32,7 @@ class ControlOpNodeTask : public NodeTask { protected: virtual Status DoExecuteAsync(TaskContext &task_context, const std::function &done_callback) const = 0; - static Status CopyTensorValueToHost(const TensorValue &tensor_value, int32_t &value); + static Status ToBool(const TensorValue &tensor_value, DataType data_type, bool &value); static Status ExecuteSubgraph(const GraphItem *subgraph, TaskContext &task_context, const std::function &done_callback); }; @@ -42,7 +42,6 @@ class IfOpNodeTask : public 
ControlOpNodeTask { Status Init(const NodePtr &node, const HybridModel &model) override; protected: - const GraphItem *SelectBranch(int32_t cond) const; Status DoExecuteAsync(TaskContext &task_context, const std::function &done_callback) const override; private: @@ -85,6 +84,7 @@ class WhileOpNodeTask : public ControlOpNodeTask { private: static constexpr int kCondBranchIndex = 0; static constexpr int kBodyBranchIndex = 1; + static constexpr size_t kCondOutputSize = 1; const GraphItem *cond_ = nullptr; const GraphItem *body_ = nullptr; diff --git a/src/ge/hybrid/node_executor/hostcpu/ge_local_node_executor.cc b/src/ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc similarity index 88% rename from src/ge/hybrid/node_executor/hostcpu/ge_local_node_executor.cc rename to src/ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc index 7cd10a83..6cf7363e 100644 --- a/src/ge/hybrid/node_executor/hostcpu/ge_local_node_executor.cc +++ b/src/ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc @@ -14,12 +14,14 @@ * limitations under the License. 
*/ -#include "hybrid/node_executor/hostcpu/ge_local_node_executor.h" +#include "hybrid/node_executor/ge_local/ge_local_node_executor.h" #include "graph/debug/ge_attr_define.h" #include "framework/common/util.h" #include "hybrid/model/hybrid_model.h" #include "inc/kernel.h" #include "inc/kernel_factory.h" +#include "common/ge/ge_util.h" +#include "hybrid/executor/hybrid_execution_context.h" namespace ge { namespace hybrid { @@ -91,12 +93,14 @@ Status RefInputTask::RefByOrder(const std::vector &ref_order, TaskCont } Status RefInputTask::ExecuteAsync(TaskContext &context, std::function done_callback) { + RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[RefInputTaskExecuteAsync] Start"); GE_CHK_STATUS_RET(Execute(context), "node:%s type:%s ref input task execute failed", node_name_.c_str(), node_type_.c_str()); if (done_callback != nullptr) { // host cpu no need register callback, call it directly. - done_callback(); + GE_CHK_STATUS_RET(context.TryExecuteCallback(done_callback)); } + RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[RefInputTaskExecuteAsync] End"); return SUCCESS; } @@ -159,18 +163,28 @@ Status DependInputShapeTask::Execute(TaskContext &context) { } Status DependInputShapeTask::ExecuteAsync(TaskContext &context, std::function done_callback) { + RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), + "[DependInputShapeTaskExecuteAsync] Start"); GE_CHK_STATUS_RET(Execute(context), "node:%s type:%s depend input shape task execute failed", node_->GetName().c_str(), node_->GetType().c_str()); if (done_callback != nullptr) { // host cpu no need register callback, call it directly. 
- done_callback(); + GE_CHK_STATUS_RET(context.TryExecuteCallback(done_callback)); } + RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), + "[DependInputShapeTaskExecuteAsync] End"); return SUCCESS; } bool DependInputShapeTask::IsBelong(const std::string &op_type) { return depend_input_shape_ops_.count(op_type) > 0; } -Status GeLocalNodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const { return task.UpdateArgs(context); } +Status GeLocalNodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const { + RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), + "[GeLocalNodeExecutorPrepareTask] Start"); + Status status = task.UpdateArgs(context); + RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[GeLocalNodeExecutorPrepareTask] End"); + return status; +} Status GeLocalNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, std::shared_ptr &task) const { @@ -226,4 +240,4 @@ Status ConstantNodeTask::ExecuteAsync(TaskContext &context, std::function do GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclRootId(op_desc, root_id), "GetHcclRootId failed"); } op_info.root = root_id; - auto callback = [this](hcclResult_t status) { + auto callback = [this, op_desc](hcclResult_t status) { if (status != HCCL_SUCCESS) { - GELOGE(HCCL_E_INTERNAL, "Call HcomExcutorInitialize failed, ret: 0x%X", status); + GELOGE(HCCL_E_INTERNAL, "node %s call EnqueueHcomOpertion failed, ret: 0x%X", op_desc->GetName().c_str(), status); } std::lock_guard lock(this->hccl_mutex_); this->cond_.notify_all(); - GELOGI("hccl callback success."); + GELOGI("node %s hccl callback success.", op_desc->GetName().c_str()); }; int32_t count = 0; - GE_CHK_STATUS_RET(HcomOmeUtil::GetHcomCount(op_desc, static_cast(op_info.dataType), false, count), + GE_CHK_STATUS_RET(HcomOmeUtil::GetHcomCount(op_desc, static_cast(op_info.dataType), + op_desc->GetType() == HCOMALLGATHER, count), "GetHcomCount failed"); GELOGI("[%s] 
HcclNodeTask::ExecuteAsync hccl_type %s, count %d, data_type %d, op_type %d, root %d.", context.GetNodeName(), op_info.hcclType.c_str(), count, op_info.dataType, op_info.opType, op_info.root); @@ -110,7 +111,7 @@ Status HcclNodeTask::ExecuteAsync(TaskContext &context, std::function do std::unique_lock ulock(hccl_mutex_); cond_.wait(ulock); - context.RegisterCallback(done_callback); + GE_CHK_STATUS_RET_NOLOG(context.RegisterCallback(done_callback)); GELOGI("[%s] HcclNodeTask::ExecuteAsync success.", context.GetNodeName()); return SUCCESS; } diff --git a/src/ge/hybrid/node_executor/hostaicpu/host_aicpu_node_executor.cc b/src/ge/hybrid/node_executor/host_cpu/host_cpu_node_executor.cc similarity index 77% rename from src/ge/hybrid/node_executor/hostaicpu/host_aicpu_node_executor.cc rename to src/ge/hybrid/node_executor/host_cpu/host_cpu_node_executor.cc index 4798b87e..fbad1fcd 100644 --- a/src/ge/hybrid/node_executor/hostaicpu/host_aicpu_node_executor.cc +++ b/src/ge/hybrid/node_executor/host_cpu/host_cpu_node_executor.cc @@ -14,8 +14,8 @@ * limitations under the License. 
*/ -#include "hybrid/node_executor/hostaicpu/host_aicpu_node_executor.h" -#include "hybrid/node_executor/hostaicpu/kernel_factory.h" +#include "hybrid/node_executor/host_cpu/host_cpu_node_executor.h" +#include "hybrid/node_executor/host_cpu/kernel_factory.h" #include "graph/passes/folding_pass.h" #include "hybrid/model/hybrid_model.h" #include "inc/kernel_factory.h" @@ -23,14 +23,14 @@ namespace ge { namespace hybrid { -REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::HOST_AICPU, HostAiCpuNodeExecutor); +REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::HOST_CPU, HostCpuNodeExecutor); -Status HostCpuNodeTaskBase::UpdateArgs(TaskContext &) { +Status HostNodeTaskBase::UpdateArgs(TaskContext &) { // no need update args return SUCCESS; } -Status HostCpuNodeTaskBase::ExecuteAsync(TaskContext &context, std::function done_callback) { +Status HostNodeTaskBase::ExecuteAsync(TaskContext &context, std::function done_callback) { GELOGD("[%s] Start execute.", context.GetNodeName()); std::vector inputs; @@ -50,7 +50,7 @@ Status HostCpuNodeTaskBase::ExecuteAsync(TaskContext &context, std::function &inputs) { +Status HostNodeTaskBase::ProcessInputs(TaskContext &context, std::vector &inputs) { int32_t input_num = context.NumInputs(); for (auto i = 0; i < input_num; ++i) { auto tensor_value = context.GetInput(i); @@ -67,7 +67,7 @@ Status HostCpuNodeTaskBase::ProcessInputs(TaskContext &context, std::vector &outputs) { +Status HostNodeTaskBase::ProcessOutputs(TaskContext &context, std::vector &outputs) { int32_t output_num = context.NumOutputs(); if (static_cast(output_num) != outputs.size()) { GELOGE(INTERNAL_ERROR, "node %s type %s has %d output, but kernel compute only has %zu output.", @@ -136,14 +136,14 @@ Status HostKernelNodeTask::Execute(TaskContext &context, const std::vector &inputs) { return SUCCESS; } +Status HostCpuNodeTask::ProcessInputs(TaskContext &context, std::vector &inputs) { return SUCCESS; } -Status 
HostAiCpuNodeTask::ProcessOutputs(TaskContext &context, std::vector &outputs) { return SUCCESS; } +Status HostCpuNodeTask::ProcessOutputs(TaskContext &context, std::vector &outputs) { return SUCCESS; } -Status HostAiCpuNodeTask::Execute(TaskContext &context, const std::vector &inputs, - std::vector &outputs) { +Status HostCpuNodeTask::Execute(TaskContext &context, const std::vector &inputs, + std::vector &outputs) { RunContext run_context; - auto host_kernel = hybrid::host_aicpu::KernelFactory::Instance().CreateKernel(node_); + auto host_kernel = hybrid::host_cpu::KernelFactory::Instance().CreateKernel(node_); if (host_kernel == nullptr) { GELOGE(UNSUPPORTED, "node %s type %s is not supported by host kernel.", node_->GetName().c_str(), node_->GetType().c_str()); @@ -160,13 +160,15 @@ Status HostAiCpuNodeTask::Execute(TaskContext &context, const std::vector &task) const { +Status HostCpuNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, + std::shared_ptr &task) const { GE_CHECK_NOTNULL(node); + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + auto mem_type = static_cast(HOST_DDR); + (void)AttrUtils::SetInt(op_desc, ATTR_OUTPUT_MEMORY_TYPE, mem_type); const std::string &name = node->GetName(); const std::string &type = node->GetType(); if (HostCpuEngine::GetInstance().CheckSupported(type)) { @@ -177,12 +179,12 @@ Status HostAiCpuNodeExecutor::LoadTask(const HybridModel &model, const NodePtr & GELOGI("create HostKernelNodeTask for node %s, type %s.", name.c_str(), type.c_str()); task = MakeShared(node); GE_CHECK_NOTNULL(task); - } else if (hybrid::host_aicpu::KernelFactory::Instance().CreateKernel(node) != nullptr) { - GELOGI("create HostAiCpuNodeTask for node %s, type %s.", name.c_str(), type.c_str()); - task = MakeShared(node); + } else if (hybrid::host_cpu::KernelFactory::Instance().CreateKernel(node) != nullptr) { + GELOGI("create HostCpuNodeTask for node %s, type %s.", name.c_str(), type.c_str()); + task = MakeShared(node); 
GE_CHECK_NOTNULL(task); } else { - GELOGE(UNSUPPORTED, "node %s type %s is not support in HostAiCpuNodeExecutor now.", name.c_str(), type.c_str()); + GELOGE(UNSUPPORTED, "node %s type %s is not support in HostCpuNodeExecutor now.", name.c_str(), type.c_str()); return UNSUPPORTED; } return SUCCESS; diff --git a/src/ge/hybrid/node_executor/hostaicpu/host_aicpu_node_executor.h b/src/ge/hybrid/node_executor/host_cpu/host_cpu_node_executor.h similarity index 68% rename from src/ge/hybrid/node_executor/hostaicpu/host_aicpu_node_executor.h rename to src/ge/hybrid/node_executor/host_cpu/host_cpu_node_executor.h index 406d1597..b27e558b 100644 --- a/src/ge/hybrid/node_executor/hostaicpu/host_aicpu_node_executor.h +++ b/src/ge/hybrid/node_executor/host_cpu/host_cpu_node_executor.h @@ -14,18 +14,18 @@ * limitations under the License. */ -#ifndef GE_HYBRID_KERNEL_HOST_AICPU_NODE_EXECUTOR_H_ -#define GE_HYBRID_KERNEL_HOST_AICPU_NODE_EXECUTOR_H_ +#ifndef GE_HYBRID_KERNEL_HOST_CPU_NODE_EXECUTOR_H_ +#define GE_HYBRID_KERNEL_HOST_CPU_NODE_EXECUTOR_H_ #include "inc/kernel.h" #include "hybrid/node_executor/node_executor.h" namespace ge { namespace hybrid { -class HostCpuNodeTaskBase : public NodeTask { +class HostNodeTaskBase : public NodeTask { public: - explicit HostCpuNodeTaskBase(const NodePtr &node) : node_(node) {} - ~HostCpuNodeTaskBase() = default; + explicit HostNodeTaskBase(const NodePtr &node) : node_(node) {} + ~HostNodeTaskBase() = default; virtual Status UpdateArgs(TaskContext &context); virtual Status ExecuteAsync(TaskContext &context, std::function done_callback); @@ -39,9 +39,9 @@ class HostCpuNodeTaskBase : public NodeTask { virtual Status ProcessOutputs(TaskContext &context, std::vector &outputs); }; -class CpuKernelNodeTask : public HostCpuNodeTaskBase { +class CpuKernelNodeTask : public HostNodeTaskBase { public: - explicit CpuKernelNodeTask(const NodePtr &node) : HostCpuNodeTaskBase(node) {} + explicit CpuKernelNodeTask(const NodePtr &node) : 
HostNodeTaskBase(node) {} ~CpuKernelNodeTask() = default; private: @@ -49,9 +49,9 @@ class CpuKernelNodeTask : public HostCpuNodeTaskBase { std::vector &outputs) override; }; -class HostKernelNodeTask : public HostCpuNodeTaskBase { +class HostKernelNodeTask : public HostNodeTaskBase { public: - explicit HostKernelNodeTask(const NodePtr &node) : HostCpuNodeTaskBase(node) {} + explicit HostKernelNodeTask(const NodePtr &node) : HostNodeTaskBase(node) {} ~HostKernelNodeTask() = default; private: @@ -59,10 +59,10 @@ class HostKernelNodeTask : public HostCpuNodeTaskBase { std::vector &outputs) override; }; -class HostAiCpuNodeTask : public HostCpuNodeTaskBase { +class HostCpuNodeTask : public HostNodeTaskBase { public: - explicit HostAiCpuNodeTask(const NodePtr &node) : HostCpuNodeTaskBase(node) {} - ~HostAiCpuNodeTask() = default; + explicit HostCpuNodeTask(const NodePtr &node) : HostNodeTaskBase(node) {} + ~HostCpuNodeTask() = default; private: Status Execute(TaskContext &context, const std::vector &inputs, @@ -71,13 +71,12 @@ class HostAiCpuNodeTask : public HostCpuNodeTaskBase { Status ProcessOutputs(TaskContext &context, std::vector &outputs) override; }; -class HostAiCpuNodeExecutor : public NodeExecutor { +class HostCpuNodeExecutor : public NodeExecutor { public: Status PrepareTask(NodeTask &task, TaskContext &context) const override; - virtual Status LoadTask(const HybridModel &model, const NodePtr &node, - std::shared_ptr &task) const override; + Status LoadTask(const HybridModel &model, const NodePtr &node, std::shared_ptr &task) const override; }; } // namespace hybrid } // namespace ge -#endif // GE_HYBRID_KERNEL_HOST_AICPU_NODE_EXECUTOR_H_ +#endif // GE_HYBRID_KERNEL_HOST_CPU_NODE_EXECUTOR_H_ diff --git a/src/ge/hybrid/node_executor/hostaicpu/kernel/assign_kernel.cc b/src/ge/hybrid/node_executor/host_cpu/kernel/assign_kernel.cc similarity index 82% rename from src/ge/hybrid/node_executor/hostaicpu/kernel/assign_kernel.cc rename to 
src/ge/hybrid/node_executor/host_cpu/kernel/assign_kernel.cc index 02ce40e2..3655fcdb 100644 --- a/src/ge/hybrid/node_executor/hostaicpu/kernel/assign_kernel.cc +++ b/src/ge/hybrid/node_executor/host_cpu/kernel/assign_kernel.cc @@ -14,10 +14,10 @@ * limitations under the License. */ -#include "hybrid/node_executor/hostaicpu/kernel/assign_kernel.h" +#include "hybrid/node_executor/host_cpu/kernel/assign_kernel.h" #include "framework/common/debug/ge_log.h" #include "framework/common/util.h" -#include "hybrid/node_executor/hostaicpu/kernel_factory.h" +#include "hybrid/node_executor/host_cpu/kernel_factory.h" namespace { const size_t kAssignInputNum = 2; @@ -28,9 +28,9 @@ const size_t kAssignRefOutputIndex = 0; namespace ge { namespace hybrid { -namespace host_aicpu { +namespace host_cpu { Status AssignKernel::Compute(TaskContext& context) { - GELOGI("AssignKernel [%s, %s] compute begin.", node_->GetName().c_str(), node_->GetType().c_str()); + GELOGI("[%s] compute begin.", node_->GetName().c_str()); auto ref_tensor = context.MutableInput(kAssignRefInputIndex); GE_CHECK_NOTNULL(ref_tensor); @@ -43,7 +43,7 @@ Status AssignKernel::Compute(TaskContext& context) { } GELOGI("[%s] value_input_data=%p, ref_input_size=%zu, value_input_size=%zu.", node_->GetName().c_str(), - ref_tensor->GetSize(), ref_tensor->GetData(), value_tensor->GetSize()); + ref_tensor->GetData(), ref_tensor->GetSize(), value_tensor->GetSize()); if (value_tensor->GetSize() > 0) { GE_CHK_RT_RET(rtMemcpy(ref_tensor->MutableData(), ref_tensor->GetSize(), value_tensor->GetData(), value_tensor->GetSize(), RT_MEMCPY_HOST_TO_HOST)); @@ -51,11 +51,11 @@ Status AssignKernel::Compute(TaskContext& context) { GE_CHK_STATUS_RET(context.SetOutput(kAssignRefOutputIndex, *ref_tensor), "[%s] Failed to set output.", context.GetNodeName()); - GELOGI("AssignKernel [%s, %s] compute success.", node_->GetName().c_str(), node_->GetType().c_str()); + GELOGI("[%s] compute success.", node_->GetName().c_str()); return SUCCESS; } 
REGISTER_KERNEL_CREATOR(Assign, AssignKernel); -} // namespace host_aicpu +} // namespace host_cpu } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/node_executor/hostaicpu/kernel/assign_kernel.h b/src/ge/hybrid/node_executor/host_cpu/kernel/assign_kernel.h similarity index 79% rename from src/ge/hybrid/node_executor/hostaicpu/kernel/assign_kernel.h rename to src/ge/hybrid/node_executor/host_cpu/kernel/assign_kernel.h index 6af30926..c3b4862b 100644 --- a/src/ge/hybrid/node_executor/hostaicpu/kernel/assign_kernel.h +++ b/src/ge/hybrid/node_executor/host_cpu/kernel/assign_kernel.h @@ -14,14 +14,14 @@ * limitations under the License. */ -#ifndef GE_HYBRID_HOST_AICPU_KERNEL_ASSIGN_KERNEL_H_ -#define GE_HYBRID_HOST_AICPU_KERNEL_ASSIGN_KERNEL_H_ +#ifndef GE_HYBRID_HOST_CPU_KERNEL_ASSIGN_KERNEL_H_ +#define GE_HYBRID_HOST_CPU_KERNEL_ASSIGN_KERNEL_H_ -#include "hybrid/node_executor/hostaicpu/kernel/kernel.h" +#include "hybrid/node_executor/host_cpu/kernel/kernel.h" namespace ge { namespace hybrid { -namespace host_aicpu { +namespace host_cpu { class AssignKernel : public Kernel { public: AssignKernel(const NodePtr &node) : Kernel(node) {} @@ -35,8 +35,8 @@ class AssignKernel : public Kernel { */ Status Compute(TaskContext &context) override; }; -} // namespace host_aicpu +} // namespace host_cpu } // namespace hybrid } // namespace ge -#endif // GE_HYBRID_HOST_AICPU_KERNEL_ASSIGN_KERNEL_H_ +#endif // GE_HYBRID_HOST_CPU_KERNEL_ASSIGN_KERNEL_H_ diff --git a/src/ge/hybrid/node_executor/hostaicpu/kernel/kernel.h b/src/ge/hybrid/node_executor/host_cpu/kernel/kernel.h similarity index 84% rename from src/ge/hybrid/node_executor/hostaicpu/kernel/kernel.h rename to src/ge/hybrid/node_executor/host_cpu/kernel/kernel.h index 0e22f62a..4fe8f8a3 100644 --- a/src/ge/hybrid/node_executor/hostaicpu/kernel/kernel.h +++ b/src/ge/hybrid/node_executor/host_cpu/kernel/kernel.h @@ -14,8 +14,8 @@ * limitations under the License. 
*/ -#ifndef GE_HYBRID_HOST_AICPU_KERNEL_KERNEL_H_ -#define GE_HYBRID_HOST_AICPU_KERNEL_KERNEL_H_ +#ifndef GE_HYBRID_HOST_CPU_KERNEL_KERNEL_H_ +#define GE_HYBRID_HOST_CPU_KERNEL_KERNEL_H_ #include "common/ge_inner_error_codes.h" #include "graph/node.h" @@ -23,7 +23,7 @@ namespace ge { namespace hybrid { -namespace host_aicpu { +namespace host_cpu { /** * The base class for all host_kernel. */ @@ -36,8 +36,8 @@ class Kernel { protected: const NodePtr &node_; }; -} // namespace host_aicpu +} // namespace host_cpu } // namespace hybrid } // namespace ge -#endif // GE_HYBRID_HOST_AICPU_KERNEL_KERNEL_H_ +#endif // GE_HYBRID_HOST_CPU_KERNEL_KERNEL_H_ diff --git a/src/ge/hybrid/node_executor/hostaicpu/kernel/no_op_kernel.cc b/src/ge/hybrid/node_executor/host_cpu/kernel/no_op_kernel.cc similarity index 76% rename from src/ge/hybrid/node_executor/hostaicpu/kernel/no_op_kernel.cc rename to src/ge/hybrid/node_executor/host_cpu/kernel/no_op_kernel.cc index 433f8d2f..47e6e534 100644 --- a/src/ge/hybrid/node_executor/hostaicpu/kernel/no_op_kernel.cc +++ b/src/ge/hybrid/node_executor/host_cpu/kernel/no_op_kernel.cc @@ -14,20 +14,20 @@ * limitations under the License. 
*/ -#include "hybrid/node_executor/hostaicpu/kernel/no_op_kernel.h" +#include "hybrid/node_executor/host_cpu/kernel/no_op_kernel.h" #include "framework/common/debug/ge_log.h" #include "framework/common/util.h" -#include "hybrid/node_executor/hostaicpu/kernel_factory.h" +#include "hybrid/node_executor/host_cpu/kernel_factory.h" namespace ge { namespace hybrid { -namespace host_aicpu { +namespace host_cpu { Status NoOpKernel::Compute(TaskContext& context) { - GELOGI("NoOpKernel [%s, %s] no need to compute.", node_->GetName().c_str(), node_->GetType().c_str()); + GELOGI("[%s] no need to compute.", node_->GetName().c_str()); return SUCCESS; } REGISTER_KERNEL_CREATOR(NoOp, NoOpKernel); -} // namespace host_aicpu +} // namespace host_cpu } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/node_executor/hostaicpu/kernel/no_op_kernel.h b/src/ge/hybrid/node_executor/host_cpu/kernel/no_op_kernel.h similarity index 79% rename from src/ge/hybrid/node_executor/hostaicpu/kernel/no_op_kernel.h rename to src/ge/hybrid/node_executor/host_cpu/kernel/no_op_kernel.h index 3c05c754..302a7e16 100644 --- a/src/ge/hybrid/node_executor/hostaicpu/kernel/no_op_kernel.h +++ b/src/ge/hybrid/node_executor/host_cpu/kernel/no_op_kernel.h @@ -14,14 +14,14 @@ * limitations under the License. 
*/ -#ifndef GE_HYBRID_HOST_AICPU_KERNEL_NO_OP_KERNEL_H_ -#define GE_HYBRID_HOST_AICPU_KERNEL_NO_OP_KERNEL_H_ +#ifndef GE_HYBRID_HOST_CPU_KERNEL_NO_OP_KERNEL_H_ +#define GE_HYBRID_HOST_CPU_KERNEL_NO_OP_KERNEL_H_ -#include "hybrid/node_executor/hostaicpu/kernel/kernel.h" +#include "hybrid/node_executor/host_cpu/kernel/kernel.h" namespace ge { namespace hybrid { -namespace host_aicpu { +namespace host_cpu { class NoOpKernel : public Kernel { public: NoOpKernel(const NodePtr &node) : Kernel(node) {} @@ -35,8 +35,8 @@ class NoOpKernel : public Kernel { */ Status Compute(TaskContext &context) override; }; -} // namespace host_aicpu +} // namespace host_cpu } // namespace hybrid } // namespace ge -#endif // GE_HYBRID_HOST_AICPU_KERNEL_NO_OP_KERNEL_H_ +#endif // GE_HYBRID_HOST_CPU_KERNEL_NO_OP_KERNEL_H_ diff --git a/src/ge/hybrid/node_executor/hostaicpu/kernel/random_uniform_kernel.cc b/src/ge/hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.cc similarity index 53% rename from src/ge/hybrid/node_executor/hostaicpu/kernel/random_uniform_kernel.cc rename to src/ge/hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.cc index dfd8f1fe..7e87c114 100644 --- a/src/ge/hybrid/node_executor/hostaicpu/kernel/random_uniform_kernel.cc +++ b/src/ge/hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.cc @@ -14,71 +14,76 @@ * limitations under the License. 
*/ -#include "hybrid/node_executor/hostaicpu/kernel/random_uniform_kernel.h" +#include "hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.h" #include #include "common/fp16_t.h" #include "framework/common/debug/ge_log.h" #include "framework/common/util.h" -#include "graph/debug/ge_attr_define.h" #include "graph/utils/type_utils.h" -#include "hybrid/node_executor/hostaicpu/kernel_factory.h" +#include "hybrid/node_executor/host_cpu/kernel_factory.h" + +namespace { +const char *const kAttrDtype = "dtype"; +} namespace ge { namespace hybrid { -namespace host_aicpu { -Status RandomUniformKernel::Compute(TaskContext& context) { - GELOGI("RandomUniformKernel [%s, %s] compute begin.", node_->GetName().c_str(), node_->GetType().c_str()); +namespace host_cpu { +Status RandomUniformKernel::Compute(TaskContext &context) { + GELOGI("[%s] compute begin.", node_->GetName().c_str()); + int64_t seed = 0; int64_t seed2 = 0; (void)AttrUtils::GetInt(node_->GetOpDesc(), "seed", seed); (void)AttrUtils::GetInt(node_->GetOpDesc(), "seed2", seed2); - DataType data_type = DT_UNDEFINED; - if (AttrUtils::GetDataType(node_->GetOpDesc(), VAR_ATTR_DTYPE, data_type) != GRAPH_SUCCESS) { - GELOGE(PARAM_INVALID, "get attr VAR_ATTR_DTYPE failed"); + DataType data_type = DT_FLOAT; + if (!AttrUtils::GetDataType(node_->GetOpDesc(), kAttrDtype, data_type)) { + GELOGE(PARAM_INVALID, "[%s] get attr dtype failed.", node_->GetName().c_str()); return PARAM_INVALID; } - switch (data_type) { case DT_FLOAT16: if (GenerateFP16(node_->GetOpDesc(), seed, seed2, context) != SUCCESS) { - GELOGE(FAILED, "Generate random_distribution for RandomUniformOp failed, data_type=DT_FLOAT"); + GELOGE(FAILED, "Generate random_distribution failed, data_type=DT_FLOAT"); return FAILED; } break; case DT_FLOAT: if (Generate(node_->GetOpDesc(), seed, seed2, context) != SUCCESS) { - GELOGE(FAILED, "Generate random_distribution for RandomUniformOp failed, data_type=DT_FLOAT"); + GELOGE(FAILED, "Generate random_distribution 
failed, data_type=DT_FLOAT"); return FAILED; } break; case DT_DOUBLE: if (Generate(node_->GetOpDesc(), seed, seed2, context) != SUCCESS) { - GELOGE(FAILED, "Generate random_distribution for RandomUniformOp failed, data_type=DT_DOUBLE"); + GELOGE(FAILED, "Generate random_distribution failed, data_type=DT_DOUBLE"); return FAILED; } break; default: - GELOGE(UNSUPPORTED, "Supported DataType for RandomUniformOp is DT_FLOAT16 / DT_FLOAT / DT_DOUBLE, but dtype=%s", + GELOGE(UNSUPPORTED, "Supported DataType is DT_FLOAT16 / DT_FLOAT / DT_DOUBLE, but data_type=%s", TypeUtils::DataTypeToSerialString(data_type).c_str()); return UNSUPPORTED; } - GELOGI("RandomUniformKernel [%s, %s] compute success.", node_->GetName().c_str(), node_->GetType().c_str()); + GELOGI("[%s] compute success.", node_->GetName().c_str()); return SUCCESS; } template -Status RandomUniformKernel::Generate(const ge::OpDescPtr& op_desc_ptr, int64_t seed, int64_t seed2, - TaskContext& context) { +Status RandomUniformKernel::Generate(const ge::OpDescPtr &op_desc_ptr, int64_t seed, int64_t seed2, + TaskContext &context) { GE_CHECK_NOTNULL(op_desc_ptr); // RandomUniformOp has and only has one output int64_t data_num = op_desc_ptr->GetOutputDesc(0).GetShape().GetShapeSize(); - std::unique_ptr buf(new (std::nothrow) T[data_num]()); - if (buf == nullptr) { - GELOGE(MEMALLOC_FAILED, "New sizeof(T) * data_num(%zu) memory failed", static_cast(sizeof(T) * data_num)); - return MEMALLOC_FAILED; - } + AllocationAttr attr; + attr.SetMemType(HOST_DDR); + auto tensor_size = data_num * sizeof(T); + TensorValue tensor; + GE_CHK_STATUS_RET(context.AllocateTensor(tensor_size, tensor, &attr), "[%s] Failed to allocate output of size %zu", + context.GetNodeName(), tensor_size); + auto *buf = reinterpret_cast(tensor.MutableData()); int64_t final_seed; if (seed == 0) { if (seed2 == 0) { @@ -93,28 +98,26 @@ Status RandomUniformKernel::Generate(const ge::OpDescPtr& op_desc_ptr, int64_t s std::mt19937_64 gen(final_seed); 
std::uniform_real_distribution distribution(0, 1); for (int64_t i = 0; i < data_num; i++) { - *(buf.get() + i) = distribution(gen); + *(buf + i) = distribution(gen); } - std::shared_ptr output = MakeShared(buf.get(), data_num * sizeof(T)); - GE_CHECK_NOTNULL(output); - GE_CHK_STATUS_RET(context.SetOutput(0, *output), "[%s] Failed to set output.", context.GetNodeName()); - + GE_CHK_STATUS_RET(context.SetOutput(0, tensor), "[%s] Failed to set output.", context.GetNodeName()); return SUCCESS; } -Status RandomUniformKernel::GenerateFP16(const ge::OpDescPtr& op_desc_ptr, int64_t seed, int64_t seed2, - TaskContext& context) { +Status RandomUniformKernel::GenerateFP16(const ge::OpDescPtr &op_desc_ptr, int64_t seed, int64_t seed2, + TaskContext &context) { GE_CHECK_NOTNULL(op_desc_ptr); // RandomUniformOp has and only has one output int64_t data_num = op_desc_ptr->GetOutputDesc(0).GetShape().GetShapeSize(); - std::unique_ptr buf(new (std::nothrow) fp16_t[data_num]()); - if (buf == nullptr) { - GELOGE(MEMALLOC_FAILED, "New sizeof(fp16_t) * data_num(%zu) memory failed", - static_cast(sizeof(fp16_t) * data_num)); - return MEMALLOC_FAILED; - } + AllocationAttr attr; + attr.SetMemType(HOST_DDR); + auto tensor_size = data_num * sizeof(fp16_t); + TensorValue tensor; + GE_CHK_STATUS_RET(context.AllocateTensor(tensor_size, tensor, &attr), "[%s] Failed to allocate output of size %zu", + context.GetNodeName(), tensor_size); + auto *buf = reinterpret_cast(tensor.MutableData()); int64_t final_seed; if (seed == 0) { if (seed2 == 0) { @@ -129,17 +132,14 @@ Status RandomUniformKernel::GenerateFP16(const ge::OpDescPtr& op_desc_ptr, int64 std::mt19937_64 gen(final_seed); std::uniform_real_distribution distribution(0, 1); for (int64_t i = 0; i < data_num; i++) { - *(buf.get() + i) = static_cast(distribution(gen)); + *(buf + i) = static_cast(distribution(gen)); } - std::shared_ptr output = MakeShared(buf.get(), data_num * sizeof(fp16_t)); - GE_CHECK_NOTNULL(output); - 
GE_CHK_STATUS_RET(context.SetOutput(0, *output), "[%s] Failed to set output.", context.GetNodeName()); - + GE_CHK_STATUS_RET(context.SetOutput(0, tensor), "[%s] Failed to set output.", context.GetNodeName()); return SUCCESS; } REGISTER_KERNEL_CREATOR(RandomUniform, RandomUniformKernel); -} // namespace host_aicpu +} // namespace host_cpu } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/node_executor/hostaicpu/kernel/random_uniform_kernel.h b/src/ge/hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.h similarity index 82% rename from src/ge/hybrid/node_executor/hostaicpu/kernel/random_uniform_kernel.h rename to src/ge/hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.h index 343c6d08..7024b103 100644 --- a/src/ge/hybrid/node_executor/hostaicpu/kernel/random_uniform_kernel.h +++ b/src/ge/hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.h @@ -14,14 +14,14 @@ * limitations under the License. */ -#ifndef GE_HYBRID_HOST_AICPU_KERNEL_RANDOM_UNIFORM_KERNEL_H_ -#define GE_HYBRID_HOST_AICPU_KERNEL_RANDOM_UNIFORM_KERNEL_H_ +#ifndef GE_HYBRID_HOST_CPU_KERNEL_RANDOM_UNIFORM_KERNEL_H_ +#define GE_HYBRID_HOST_CPU_KERNEL_RANDOM_UNIFORM_KERNEL_H_ -#include "hybrid/node_executor/hostaicpu/kernel/kernel.h" +#include "hybrid/node_executor/host_cpu/kernel/kernel.h" namespace ge { namespace hybrid { -namespace host_aicpu { +namespace host_cpu { class RandomUniformKernel : public Kernel { public: RandomUniformKernel(const NodePtr &node) : Kernel(node) {} @@ -41,8 +41,8 @@ class RandomUniformKernel : public Kernel { static Status GenerateFP16(const ge::OpDescPtr &op_desc_ptr, int64_t seed, int64_t seed2, TaskContext &context); }; -} // namespace host_aicpu +} // namespace host_cpu } // namespace hybrid } // namespace ge -#endif // GE_HYBRID_HOST_AICPU_KERNEL_RANDOM_UNIFORM_KERNEL_H_ +#endif // GE_HYBRID_HOST_CPU_KERNEL_RANDOM_UNIFORM_KERNEL_H_ diff --git a/src/ge/hybrid/node_executor/hostaicpu/kernel/variable_kernel.cc 
b/src/ge/hybrid/node_executor/host_cpu/kernel/variable_kernel.cc similarity index 76% rename from src/ge/hybrid/node_executor/hostaicpu/kernel/variable_kernel.cc rename to src/ge/hybrid/node_executor/host_cpu/kernel/variable_kernel.cc index a8259500..db5c0f9c 100644 --- a/src/ge/hybrid/node_executor/hostaicpu/kernel/variable_kernel.cc +++ b/src/ge/hybrid/node_executor/host_cpu/kernel/variable_kernel.cc @@ -14,16 +14,16 @@ * limitations under the License. */ -#include "hybrid/node_executor/hostaicpu/kernel/variable_kernel.h" +#include "hybrid/node_executor/host_cpu/kernel/variable_kernel.h" #include "framework/common/debug/ge_log.h" #include "framework/common/util.h" -#include "hybrid/node_executor/hostaicpu/kernel_factory.h" +#include "hybrid/node_executor/host_cpu/kernel_factory.h" namespace ge { namespace hybrid { -namespace host_aicpu { +namespace host_cpu { Status VariableKernel::Compute(TaskContext& context) { - GELOGI("VariableKernel [%s, %s] compute begin.", node_->GetName().c_str(), node_->GetType().c_str()); + GELOGI("[%s] compute begin.", node_->GetName().c_str()); auto tensor = context.GetVariable(node_->GetName()); if (tensor == nullptr) { @@ -32,12 +32,12 @@ Status VariableKernel::Compute(TaskContext& context) { } // Constant & Variable Op has and only has one output GE_CHK_STATUS_RET(context.SetOutput(0, *tensor), "[%s] Failed to set output.", context.GetNodeName()); - GELOGI("VariableKernel [%s, %s] compute success.", node_->GetName().c_str(), node_->GetType().c_str()); + GELOGI("[%s] compute success.", node_->GetName().c_str()); return SUCCESS; } REGISTER_KERNEL_CREATOR(Variable, VariableKernel); REGISTER_KERNEL_CREATOR(Constant, VariableKernel); -} // namespace host_aicpu +} // namespace host_cpu } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/node_executor/hostaicpu/kernel/variable_kernel.h b/src/ge/hybrid/node_executor/host_cpu/kernel/variable_kernel.h similarity index 79% rename from 
src/ge/hybrid/node_executor/hostaicpu/kernel/variable_kernel.h rename to src/ge/hybrid/node_executor/host_cpu/kernel/variable_kernel.h index cb0a6834..1625e49e 100644 --- a/src/ge/hybrid/node_executor/hostaicpu/kernel/variable_kernel.h +++ b/src/ge/hybrid/node_executor/host_cpu/kernel/variable_kernel.h @@ -14,14 +14,14 @@ * limitations under the License. */ -#ifndef GE_HYBRID_HOST_AICPU_KERNEL_VARIABLE_KERNEL_H_ -#define GE_HYBRID_HOST_AICPU_KERNEL_VARIABLE_KERNEL_H_ +#ifndef GE_HYBRID_HOST_CPU_KERNEL_VARIABLE_KERNEL_H_ +#define GE_HYBRID_HOST_CPU_KERNEL_VARIABLE_KERNEL_H_ -#include "hybrid/node_executor/hostaicpu/kernel/kernel.h" +#include "hybrid/node_executor/host_cpu/kernel/kernel.h" namespace ge { namespace hybrid { -namespace host_aicpu { +namespace host_cpu { class VariableKernel : public Kernel { public: VariableKernel(const NodePtr &node) : Kernel(node) {} @@ -35,8 +35,8 @@ class VariableKernel : public Kernel { */ Status Compute(TaskContext &context) override; }; -} // namespace host_aicpu +} // namespace host_cpu } // namespace hybrid } // namespace ge -#endif // GE_HYBRID_HOST_AICPU_KERNEL_VARIABLE_KERNEL_H_ +#endif // GE_HYBRID_HOST_CPU_KERNEL_VARIABLE_KERNEL_H_ diff --git a/src/ge/hybrid/node_executor/hostaicpu/kernel_factory.cc b/src/ge/hybrid/node_executor/host_cpu/kernel_factory.cc similarity index 93% rename from src/ge/hybrid/node_executor/hostaicpu/kernel_factory.cc rename to src/ge/hybrid/node_executor/host_cpu/kernel_factory.cc index ca398500..83899fa6 100644 --- a/src/ge/hybrid/node_executor/hostaicpu/kernel_factory.cc +++ b/src/ge/hybrid/node_executor/host_cpu/kernel_factory.cc @@ -14,12 +14,12 @@ * limitations under the License. 
*/ -#include "hybrid/node_executor/hostaicpu/kernel_factory.h" +#include "hybrid/node_executor/host_cpu/kernel_factory.h" #include "framework/common/debug/ge_log.h" namespace ge { namespace hybrid { -namespace host_aicpu { +namespace host_cpu { KernelFactory &KernelFactory::Instance() { static KernelFactory instance; return instance; @@ -50,6 +50,6 @@ void KernelFactory::RegisterCreator(const std::string &type, const KERNEL_CREATO } kernel_creator_map_[type] = func; } -} // namespace host_aicpu +} // namespace host_cpu } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/node_executor/hostaicpu/kernel_factory.h b/src/ge/hybrid/node_executor/host_cpu/kernel_factory.h similarity index 87% rename from src/ge/hybrid/node_executor/hostaicpu/kernel_factory.h rename to src/ge/hybrid/node_executor/host_cpu/kernel_factory.h index 9ead2005..4923756b 100644 --- a/src/ge/hybrid/node_executor/hostaicpu/kernel_factory.h +++ b/src/ge/hybrid/node_executor/host_cpu/kernel_factory.h @@ -14,22 +14,22 @@ * limitations under the License. */ -#ifndef GE_HYBRID_NODE_EXECUTOR_HOST_AICPU_KERNEL_FACTORY_H_ -#define GE_HYBRID_NODE_EXECUTOR_HOST_AICPU_KERNEL_FACTORY_H_ +#ifndef GE_HYBRID_NODE_EXECUTOR_HOST_CPU_KERNEL_FACTORY_H_ +#define GE_HYBRID_NODE_EXECUTOR_HOST_CPU_KERNEL_FACTORY_H_ #include #include #include #include "common/ge/ge_util.h" -#include "hybrid/node_executor/hostaicpu/kernel/kernel.h" +#include "hybrid/node_executor/host_cpu/kernel/kernel.h" namespace ge { namespace hybrid { -namespace host_aicpu { +namespace host_cpu { using KERNEL_CREATOR_FUNC = std::function(const NodePtr &)>; /** - * manage all the host_aicpu_kernel, support create kernel. + * manage all the host_cpu_kernel, support create kernel. 
*/ class KernelFactory { public: @@ -79,8 +79,8 @@ class KernelRegistrar { #define REGISTER_KERNEL_CREATOR(type, clazz) \ std::shared_ptr Creator_##type##Kernel(const NodePtr &node) { return MakeShared(node); } \ KernelRegistrar g_##type##Kernel_creator(#type, Creator_##type##Kernel) -} // namespace host_aicpu +} // namespace host_cpu } // namespace hybrid } // namespace ge -#endif // GE_HYBRID_NODE_EXECUTOR_HOST_AICPU_KERNEL_FACTORY_H_ +#endif // GE_HYBRID_NODE_EXECUTOR_HOST_CPU_KERNEL_FACTORY_H_ diff --git a/src/ge/hybrid/node_executor/node_executor.cc b/src/ge/hybrid/node_executor/node_executor.cc index 0f4c5494..8de15ea0 100644 --- a/src/ge/hybrid/node_executor/node_executor.cc +++ b/src/ge/hybrid/node_executor/node_executor.cc @@ -27,6 +27,8 @@ const char *const kEngineNameAiCore = "AIcoreEngine"; const char *const kEngineNameGeLocal = "DNN_VM_GE_LOCAL_OP_STORE"; const char *const kEngineNameAiCpu = "aicpu_kernel"; const char *const kEngineNameHccl = "ops_kernel_info_hccl"; +const char *const kEngineNameRts = "DNN_VM_RTS_OP_STORE"; +const char *const kEngineNameHostCpu = "DNN_VM_HOST_CPU_OP_STORE"; } // namespace Status NodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const { GE_CHK_STATUS_RET_NOLOG(context.AllocateOutputs()); @@ -61,6 +63,8 @@ Status NodeExecutorManager::EnsureInitialized() { engine_mapping_.emplace(kEngineNameGeLocal, NodeExecutorManager::ExecutorType::GE_LOCAL); engine_mapping_.emplace(kEngineNameAiCpu, NodeExecutorManager::ExecutorType::AICPU_TF); engine_mapping_.emplace(kEngineNameHccl, NodeExecutorManager::ExecutorType::HCCL); + engine_mapping_.emplace(kEngineNameRts, NodeExecutorManager::ExecutorType::RTS); + engine_mapping_.emplace(kEngineNameHostCpu, NodeExecutorManager::ExecutorType::HOST_CPU); std::shared_ptr instance_ptr = GELib::GetInstance(); if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) { @@ -135,6 +139,10 @@ Status NodeExecutorManager::CalcOpRunningParam(Node &node) const { GELOGD("[%s] Skipping 
CalcOpRunningParam for PartitionedCall.", node.GetName().c_str()); return SUCCESS; } + for (size_t i = 0; i < node.GetOpDesc()->GetOutputsSize(); ++i) { + GeTensorDescPtr output_tensor = op_desc->MutableOutputDesc(static_cast(i)); + TensorUtils::SetSize(*(output_tensor.get()), 0); + } auto it = kernel_stores_.find(op_desc->GetOpKernelLibName()); if (it == kernel_stores_.end()) { diff --git a/src/ge/hybrid/node_executor/node_executor.h b/src/ge/hybrid/node_executor/node_executor.h index 23e52428..79726b09 100644 --- a/src/ge/hybrid/node_executor/node_executor.h +++ b/src/ge/hybrid/node_executor/node_executor.h @@ -134,6 +134,8 @@ class NodeExecutorManager { GE_LOCAL, CONTROL_OP, HCCL, + RTS, + HOST_CPU, RESERVED }; diff --git a/src/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc b/src/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc index cda9a275..4c9cf7bf 100644 --- a/src/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc +++ b/src/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc @@ -74,7 +74,9 @@ Status PartitionedCallNodeExecutor::LoadTask(const ge::hybrid::HybridModel &mode } Status PartitionedCallNodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const { + RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[PartitionedCallPrepareTask] Start"); GE_CHK_STATUS_RET(task.Init(context), "[%s] Failed to init task.", context.GetNodeName()); + RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[PartitionedCallPrepareTask] End"); return SUCCESS; } } // namespace hybrid diff --git a/src/ge/hybrid/node_executor/rts/rts_node_executor.cc b/src/ge/hybrid/node_executor/rts/rts_node_executor.cc new file mode 100644 index 00000000..51241e55 --- /dev/null +++ b/src/ge/hybrid/node_executor/rts/rts_node_executor.cc @@ -0,0 +1,91 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under 
the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "rts_node_executor.h" +#include "common/debug/log.h" +#include "common/ge/ge_util.h" +#include "graph/utils/tensor_utils.h" +#include "runtime/rt.h" + +namespace ge { +namespace hybrid { +REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::RTS, RtsNodeExecutor); + +Status IdentityNodeTask::DoCopyTensor(TaskContext &context, int index) { + auto input_desc = context.MutableInputDesc(index); + GE_CHECK_NOTNULL(input_desc); + int64_t copy_size = 0; + GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorSizeInBytes(*input_desc, copy_size)); + // copy_size would not be negative since GetTensorSizeInBytes returned successfully. 
+ if (copy_size != 0) { + GELOGD("[%s] index = %d, copy size = %ld", context.GetNodeName(), index, copy_size); + auto input = context.MutableInput(index); + auto output = context.MutableOutput(index); + GE_CHECK_NOTNULL(input); + GE_CHECK_NOTNULL(output); + GE_CHK_RT_RET(rtMemcpyAsync(output->MutableData(), output->GetSize(), input->GetData(), copy_size, + RT_MEMCPY_DEVICE_TO_DEVICE, context.GetStream())); + } else { + GELOGW("[%s] index = %d, copy size = 0", context.GetNodeName(), index); + } + + return SUCCESS; +} + +Status IdentityNodeTask::ExecuteAsync(TaskContext &context, std::function done_callback) { + GELOGD("[%s] Start to execute.", context.GetNodeName()); + GE_CHK_STATUS_RET(DoCopyTensor(context, 0)); + + if (done_callback) { + GE_CHK_STATUS_RET(context.RegisterCallback(done_callback)); + } + + GELOGD("[%s] Done executing successfully.", context.GetNodeName()); + return SUCCESS; +} + +Status IdentityNodeTask::UpdateArgs(TaskContext &context) { return SUCCESS; } + +Status IdentityNNodeTask::ExecuteAsync(TaskContext &context, std::function done_callback) { + GELOGD("[%s] Start to execute.", context.GetNodeName()); + for (int i = 0; i < context.NumInputs(); ++i) { + GE_CHK_STATUS_RET(DoCopyTensor(context, i)); + } + + if (done_callback) { + GE_CHK_STATUS_RET(context.RegisterCallback(done_callback)); + } + + GELOGD("[%s] Done executing successfully.", context.GetNodeName()); + return SUCCESS; +} + +Status RtsNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr &task) const { + auto op_type = node->GetType(); + if (op_type == IDENTITY) { + task = MakeShared(); + } else if (op_type == IDENTITYN) { + task = MakeShared(); + } else { + GELOGE(INTERNAL_ERROR, "[%s] Unsupported RTS op type: %s", node->GetName().c_str(), op_type.c_str()); + return INTERNAL_ERROR; + } + + GE_CHECK_NOTNULL(task); + return SUCCESS; +} +} // namespace hybrid +} // namespace ge \ No newline at end of file diff --git 
a/src/ge/hybrid/node_executor/rts/rts_node_executor.h b/src/ge/hybrid/node_executor/rts/rts_node_executor.h new file mode 100644 index 00000000..9da28966 --- /dev/null +++ b/src/ge/hybrid/node_executor/rts/rts_node_executor.h @@ -0,0 +1,45 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HYBRID_NODE_EXECUTOR_RTS_RTS_NODE_EXECUTOR_H_ +#define GE_HYBRID_NODE_EXECUTOR_RTS_RTS_NODE_EXECUTOR_H_ + +#include "hybrid/node_executor/node_executor.h" + +namespace ge { +namespace hybrid { +class IdentityNodeTask : public NodeTask { + public: + Status UpdateArgs(TaskContext &context) override; + Status ExecuteAsync(TaskContext &context, std::function done_callback) override; + + protected: + static Status DoCopyTensor(TaskContext &context, int index); +}; + +class IdentityNNodeTask : public IdentityNodeTask { + public: + Status ExecuteAsync(TaskContext &context, std::function done_callback) override; +}; + +class RtsNodeExecutor : public NodeExecutor { + public: + Status LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr &task) const override; +}; +} // namespace hybrid +} // namespace ge + +#endif // GE_HYBRID_NODE_EXECUTOR_RTS_RTS_NODE_EXECUTOR_H_ diff --git a/src/ge/hybrid/node_executor/task_context.cc b/src/ge/hybrid/node_executor/task_context.cc index ee35bffa..dd833fe1 100644 --- a/src/ge/hybrid/node_executor/task_context.cc +++ 
b/src/ge/hybrid/node_executor/task_context.cc @@ -210,20 +210,6 @@ Status TaskContext::AllocateOutput(int index, const GeTensorDesc &tensor_desc, T } } - // Temp modification - if (node_item_->node_type == "UnsortedSegmentSum" || node_item_->node_type == "UnsortedSegmentSumD" || - node_item_->node_type == "ScatterNd") { - auto &out_tensor = outputs_start_[index]; - GELOGD("[%s] clear output tensor: %s", GetNodeName(), out_tensor.DebugString().c_str()); - auto *ctx = GetExecutionContext(); - string name = "rtMemsetAsync" + node_item_->node_name; - RegisterCallback([ctx, name]() { RECORD_CALLBACK_EVENT(ctx, name.c_str(), "[Compute] Start"); }); - RECORD_EXECUTION_EVENT(GetExecutionContext(), node_item_->node_name.c_str(), "[rtMemsetAsync] Start"); - GE_CHK_RT_RET(rtMemsetAsync(out_tensor.MutableData(), out_tensor.GetSize(), 0, out_tensor.GetSize(), GetStream())); - RECORD_EXECUTION_EVENT(GetExecutionContext(), node_item_->node_name.c_str(), "[rtMemsetAsync] End"); - RegisterCallback([ctx, name]() { RECORD_CALLBACK_EVENT(ctx, name.c_str(), "[Compute] End"); }); - } - if (execution_context_->trace_enabled) { outputs_start_[index].SetName(node_item_->NodeName() + "_out_" + std::to_string(index)); } @@ -245,8 +231,8 @@ Status TaskContext::AllocateOutputs(AllocationAttr *attr) { return SUCCESS; } -Status TaskContext::AllocateTemp(size_t size, TensorValue &tensor) { - auto buffer = TensorBuffer::Create(execution_context_->allocator, size); +Status TaskContext::AllocateTensor(size_t size, TensorValue &tensor, AllocationAttr *attr) { + auto buffer = TensorBuffer::Create(execution_context_->allocator, size, attr); if (buffer == nullptr) { GELOGE(MEMALLOC_FAILED, "Failed to allocate buffer of size: %zu", size); return MEMALLOC_FAILED; @@ -275,7 +261,12 @@ int64_t TaskContext::GetSessionId() const { return execution_context_->session_i Status TaskContext::GetStatus() const { return status_; } -void TaskContext::SetStatus(Status status) { status_ = status; } +void 
TaskContext::SetStatus(Status status) { + status_ = status; + if (status != SUCCESS) { + execution_context_->SetErrorCode(status); + } +} Status TaskContext::AllocateWorkspace(size_t size, void **buffer, void *ori_addr) { GE_CHECK_NOTNULL(buffer); @@ -322,7 +313,8 @@ Status TaskContext::PropagateOutputs() { subgraph_context_->all_inputs_[input_offset] = *tensor; if (execution_context_->trace_enabled) { - subgraph_context_->all_inputs_[input_offset].SetName(node_item_->NodeName() + "_in_" + std::to_string(i)); + subgraph_context_->all_inputs_[input_offset].SetName(node_item_->NodeName() + "_in_" + + std::to_string(dst_input_idx)); } } } @@ -364,7 +356,10 @@ void TaskContext::SetForceInferShape(bool force_infer_shape) { force_infer_shape void TaskContext::NodeDone() { subgraph_context_->NodeDone(node_item_->node); } -void TaskContext::OnError(Status error) { subgraph_context_->OnError(error); } +void TaskContext::OnError(Status error) { + subgraph_context_->OnError(error); + execution_context_->SetErrorCode(error); +} bool TaskContext::IsTraceEnabled() const { return execution_context_->trace_enabled; } @@ -373,5 +368,19 @@ TensorValue *TaskContext::GetVariable(const std::string &name) { return executio uint64_t TaskContext::GetIterationNumber() const { return iteration_; } bool TaskContext::IsDumpEnabled() const { return execution_context_->dump_enabled; } + +Status TaskContext::TryExecuteCallback(const function &callback_fun) const { + if (!callback_fun) { + return SUCCESS; + } + + if (node_item_->has_observer) { + return RegisterCallback(callback_fun); + } + + callback_fun(); + return SUCCESS; +} +const DumpProperties &TaskContext::GetDumpProperties() const { return execution_context_->dump_properties; } } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/node_executor/task_context.h b/src/ge/hybrid/node_executor/task_context.h index 5c42a347..ed45116d 100644 --- a/src/ge/hybrid/node_executor/task_context.h +++ 
b/src/ge/hybrid/node_executor/task_context.h @@ -20,6 +20,7 @@ #include #include #include +#include "common/properties_manager.h" #include "external/ge/ge_api_error_codes.h" #include "hybrid/common/tensor_value.h" #include "hybrid/common/npu_memory_allocator.h" @@ -71,13 +72,16 @@ class TaskContext { bool IsDumpEnabled() const; + const DumpProperties &GetDumpProperties() const; + const GraphExecutionContext *GetExecutionContext() { return execution_context_; } - Status AllocateTemp(size_t size, TensorValue &tensor); + Status AllocateTensor(size_t size, TensorValue &tensor, AllocationAttr *attr = nullptr); void *MutableWorkspace(int index); const void *GetVarBaseAddr(); Status RegisterCallback(const std::function &callback_fun) const; + Status TryExecuteCallback(const std::function &callback_fun) const; Status PropagateOutputs(); diff --git a/src/ge/init/gelib.cc b/src/ge/init/gelib.cc index f7740a3c..0532321e 100644 --- a/src/ge/init/gelib.cc +++ b/src/ge/init/gelib.cc @@ -47,6 +47,8 @@ namespace ge { namespace { const int kDecimal = 10; const int kSocVersionLen = 50; +const int kDefaultDeviceIdForTrain = 0; +const int kDefaultDeviceIdForInfer = -1; const uint32_t kAicoreOverflow = (0x1 << 0); const uint32_t kAtomicOverflow = (0x1 << 1); const uint32_t kAllOverflow = (kAicoreOverflow | kAtomicOverflow); @@ -157,8 +159,12 @@ Status GELib::SystemInitialize(const map &options) { // In train and infer, profiling is always needed. 
InitOptions(options); InitProfiling(this->options_); - - if (is_train_mode_) { + // 1.`is_train_mode_` means case: train + // 2.`(!is_train_mode_) && (options_.device_id != kDefaultDeviceIdForInfer)` means case: online infer + // these two case need call `InitSystemWithOptions->rtGetDeviceIndexByPhyId` + // to convert phy device id to logical device id + // note:rtGetDeviceIndexByPhyId return `0` logical id when input phy device id is `0` + if (is_train_mode_ || (options_.device_id != kDefaultDeviceIdForInfer)) { status = InitSystemWithOptions(this->options_); } else { status = InitSystemWithoutOptions(); @@ -200,7 +206,7 @@ void GELib::InitOptions(const map &options) { if (iter != options.end()) { this->options_.session_id = std::strtoll(iter->second.c_str(), nullptr, kDecimal); } - this->options_.device_id = 0; + this->options_.device_id = is_train_mode_ ? kDefaultDeviceIdForTrain : kDefaultDeviceIdForInfer; iter = options.find(OPTION_EXEC_DEVICE_ID); if (iter != options.end()) { this->options_.device_id = static_cast(std::strtol(iter->second.c_str(), nullptr, kDecimal)); @@ -252,7 +258,8 @@ void GELib::InitOptions(const map &options) { } FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status GELib::InitSystemWithOptions(Options &options) { - GELOGI("Training init GELib. session Id:%ld, device id :%d ", options.session_id, options.device_id); + std::string mode = is_train_mode_ ? "Training" : "Online infer"; + GELOGI("%s init GELib. 
session Id:%ld, device id :%d ", mode.c_str(), options.session_id, options.device_id); GEEVENT("System init with options begin, job id %s", options.job_id.c_str()); std::lock_guard lock(status_mutex_); GE_IF_BOOL_EXEC(is_system_inited && !is_shutdown, @@ -292,14 +299,14 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status GELib::InitSystemWithOpt is_system_inited = true; is_shutdown = false; - GELOGI("Training init GELib success."); + GELOGI("%s init GELib success.", mode.c_str()); return SUCCESS; } Status GELib::SystemShutdownWithOptions(const Options &options) { - GELOGI("Training finalize GELib begin."); - + std::string mode = is_train_mode_ ? "Training" : "Online infer"; + GELOGI("%s finalize GELib begin.", mode.c_str()); std::lock_guard lock(status_mutex_); GE_IF_BOOL_EXEC(is_shutdown || !is_system_inited, GELOGW("System Shutdown with options is already is_shutdown or system does not inited. " @@ -315,9 +322,7 @@ Status GELib::SystemShutdownWithOptions(const Options &options) { is_system_inited = false; is_shutdown = true; - - GELOGI("Training finalize GELib success."); - + GELOGI("%s finalize GELib success.", mode.c_str()); return SUCCESS; } @@ -387,7 +392,7 @@ Status GELib::Finalize() { // Shut down profiling ShutDownProfiling(); - if (is_train_mode_) { + if (is_train_mode_ || (options_.device_id != kDefaultDeviceIdForInfer)) { GELOGI("System ShutDown."); mid_state = SystemShutdownWithOptions(this->options_); if (mid_state != SUCCESS) { diff --git a/src/ge/ir_build/atc_ir_common.cc b/src/ge/ir_build/atc_ir_common.cc index 91fa17d4..dbfe688b 100644 --- a/src/ge/ir_build/atc_ir_common.cc +++ b/src/ge/ir_build/atc_ir_common.cc @@ -16,7 +16,6 @@ #include "atc_ir_common.h" #include "common/util/error_manager/error_manager.h" -#include "common/model_parser/graph_parser_util.h" #include "external/ge/ge_api_types.h" #include "framework/common/string_util.h" #include "framework/common/types.h" @@ -43,10 +42,25 @@ const std::set kBufferOptimizeSupportOption = 
{"l1_optimize", "l2_o const char *const kBufferOptimizeSupport = "only support l2_optimize, off_optimize"; const char *const IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT = "high_performance"; const char *const IR_OPTION_OP_SELECT_IMPLMODE_PRECISON = "high_precision"; +const char *const kInputShapeSample1 = "\"input_name1:n1,c1,h1,w1\""; +const char *const kInputShapeSample2 = "\"input_name1:1,3,224,224\""; +const char *const kSplitError1 = "size not equal to 2 split by \":\""; +const char *const kEmptyError = "can not be empty"; +const char *const kFloatNumError = "exist float number"; +const char *const kDigitError = "is not digit"; const char *const kCompressWeightError = "it must be appointed when appoint parameter[--optypelist_for_implmode]"; const char *const kSelectImplmodeError = "only support high_performance, high_precision"; const char *const kDynamicBatchSizeError = "It can only contains digit, \",\", \" \""; +vector SplitInputShape(const std::string &input_shape) { + vector shape_pair_vec; + size_t pos = input_shape.rfind(":"); + if (pos != std::string::npos) { + shape_pair_vec.emplace_back(input_shape.substr(0, pos)); + shape_pair_vec.emplace_back(input_shape.substr(pos + 1, input_shape.size() - pos)); + } + return shape_pair_vec; +} } // namespace bool CheckDynamicBatchSizeInputShapeValid(unordered_map> shape_map, @@ -182,15 +196,7 @@ bool CheckDynamicDimsInputShapeValid(const unordered_map GELOGE(ge::PARAM_INVALID, "Dim num must within [%zu, %zu] when set dynamic_dims.", kMinNDDimNum, kMaxNDDimNum); return false; } - int tmp = std::count(shapes.begin(), shapes.end(), kDynamicInputDim); - if (dynamic_dim != 0 && dynamic_dim != tmp) { - ErrorManager::GetInstance().ATCReportErrMessage( - "E10001", {"parameter", "value", "reason"}, - {"--input_shape's -1 num", std::to_string(tmp), "Every set's num of -1 must be same"}); - GELOGE(ge::PARAM_INVALID, "input_shape's shape is invalid, every set's num of -1 must be same."); - return false; - } - dynamic_dim = tmp; + 
dynamic_dim += std::count(shapes.begin(), shapes.end(), kDynamicInputDim); } if (dynamic_dim == 0) { ErrorManager::GetInstance().ATCReportErrMessage( @@ -229,10 +235,11 @@ bool CheckAndParseDynamicDims(int32_t dynamic_dim_num, std::string &dynamic_dims vector one_set = StringUtils::Split(split_dim, ','); if (one_set.size() != static_cast(dynamic_dim_num)) { ErrorManager::GetInstance().ATCReportErrMessage( - "E10001", {"parameter", "value", "reason"}, - {"--dynamic_dims's parameter num of each set", std::to_string(one_set.size()), - "must be same as input_shape's num of -1"}); - GELOGE(ge::PARAM_INVALID, "dynamic_dims's parameter num of each set must be same as input_shape's num of -1."); + "E10042", {"parameter", "reason"}, + {"dynamic_dims", "Each gear setting needs to be consistent with the number of -1 in the inputshape"}); + GELOGE(ge::PARAM_INVALID, + "Input parameter --dynamic_dims parse failed, " + "reason: Each gear setting needs to be consistent with the number of -1 in the inputshape."); return false; } for (auto dim : one_set) { @@ -303,6 +310,90 @@ Status CheckDynamicInputParamValid(string &dynamic_batch_size, string &dynamic_i return ge::SUCCESS; } +bool ParseInputShape(const string &input_shape, unordered_map> &shape_map, + vector>> &user_shape_map, bool is_dynamic_input) { + vector shape_vec = StringUtils::Split(input_shape, ';'); + const int DEFAULT_SHAPE_PAIR_SIZE = 2; + for (const auto &shape : shape_vec) { + vector shape_pair_vec = SplitInputShape(shape); + if (shape_pair_vec.size() != DEFAULT_SHAPE_PAIR_SIZE) { + ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"}, + {shape, kSplitError1, kInputShapeSample1}); + GELOGW("Parse input parameter [--input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.", + shape.c_str(), kSplitError1, kInputShapeSample1); + return false; + } + if (shape_pair_vec[1].empty()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"}, + 
{shape, kEmptyError, kInputShapeSample1}); + GELOGW("Parse input parameter [--input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.", + shape.c_str(), kEmptyError, kInputShapeSample1); + return false; + } + + vector shape_value_strs = StringUtils::Split(shape_pair_vec[1], ','); + vector shape_values; + for (auto &shape_value_str : shape_value_strs) { + // stoul: The method may throw an exception: invalid_argument/out_of_range + if (std::string::npos != shape_value_str.find('.')) { + ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"}, + {shape, kFloatNumError, kInputShapeSample2}); + GELOGW("Parse input parameter [--input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.", + shape.c_str(), kFloatNumError, kInputShapeSample2); + return false; + } + + long left_result = 0; + try { + left_result = stol(StringUtils::Trim(shape_value_str)); + if (!shape_value_str.empty() && (shape_value_str.front() == '-')) { + // The value maybe dynamic shape [-1], need substr it and verify isdigit. 
+ shape_value_str = shape_value_str.substr(1); + } + for (char c : shape_value_str) { + if (!isdigit(c)) { + ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"}, + {shape, kDigitError, kInputShapeSample2}); + GELOGE(PARAM_INVALID, "--input_shape's shape value[%s] is not digit", shape_value_str.c_str()); + return false; + } + } + } catch (const std::out_of_range &) { + ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, + {"--input_shape", shape_value_str}); + GELOGW("Input parameter[--input_shape]Ă¢â‚¬â„¢s value[%s] cause out of range execption!", shape_value_str.c_str()); + return false; + } catch (const std::invalid_argument &) { + ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, + {"--input_shape", shape_value_str}); + GELOGW("Input parameter[--input_shape]Ă¢â‚¬â„¢s value[%s] cause invalid argument!", shape_value_str.c_str()); + return false; + } catch (...) { + ErrorManager::GetInstance().ATCReportErrMessage("E10015", {"parameter", "value"}, + {"--input_shape", shape_value_str}); + GELOGW("Input parameter[--input_shape]Ă¢â‚¬â„¢s value[%s] cause unkown execption!", shape_value_str.c_str()); + return false; + } + int64_t result = left_result; + // - 1 is not currently supported + if (!is_dynamic_input && result <= 0) { + ErrorManager::GetInstance().ATCReportErrMessage("E10011", {"shape", "result"}, {shape, std::to_string(result)}); + GELOGW( + "Input parameter[--input_shape]Ă¢â‚¬â„¢s shape value[%s] is invalid, " + "expect positive integer, but value is %ld.", + shape.c_str(), result); + return false; + } + shape_values.push_back(result); + } + + shape_map.emplace(make_pair(StringUtils::Trim(shape_pair_vec[0]), shape_values)); + user_shape_map.push_back(make_pair(StringUtils::Trim(shape_pair_vec[0]), shape_values)); + } + + return true; +} + Status CheckOutputTypeParamValid(const std::string output_type) { if ((!output_type.empty()) && 
(kOutputTypeSupportDatatype.find(output_type) == kOutputTypeSupportDatatype.end())) { ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, diff --git a/src/ge/ir_build/atc_ir_common.h b/src/ge/ir_build/atc_ir_common.h index e4d3103b..53143c2b 100644 --- a/src/ge/ir_build/atc_ir_common.h +++ b/src/ge/ir_build/atc_ir_common.h @@ -58,6 +58,9 @@ Status CheckDynamicInputParamValid(std::string &dynamic_batch_size, std::string std::string &dynamic_dims, const std::string input_shape, const std::string input_format, bool &is_dynamic_input); +bool ParseInputShape(const std::string &input_shape, std::unordered_map> &shape_map, + std::vector>> &user_shape_map, bool is_dynamic_input = false); + Status CheckOutputTypeParamValid(const std::string output_type); Status CheckBufferOptimizeParamValid(const std::string buffer_optimize); Status CheckCompressWeightParamValid(const std::string enable_compress_weight, const std::string compress_weight_conf); diff --git a/src/ge/ir_build/ge_ir_build.cc b/src/ge/ir_build/ge_ir_build.cc index a9ff1ab5..0a60fa11 100644 --- a/src/ge/ir_build/ge_ir_build.cc +++ b/src/ge/ir_build/ge_ir_build.cc @@ -26,7 +26,6 @@ #include "framework/common/util.h" #include "framework/omg/omg_inner_types.h" #include "framework/omg/omg_inner_types.h" -#include "common/model_parser/graph_parser_util.h" #include "ge/ge_api_types.h" #include "generator/ge_generator.h" #include "graph/compute_graph.h" @@ -184,7 +183,7 @@ graphStatus Impl::Init(const std::map &options) { // 1. check options graphStatus ret = CheckOptions(options); if (ret != GRAPH_SUCCESS) { - GELOGE(ret, "user input options are illegal! Please check!"); + GELOGE(ret, "User input options are illegal! 
Please check!"); return ret; } // set log level @@ -255,10 +254,10 @@ graphStatus Impl::CreateInputsForIRBuild(const ge::Graph &graph, vectorGetType() == DATA) { (void)AttrUtils::SetInt(op, ATTR_NAME_INDEX, index++); - GELOGI("Data op inputDesc size is: %zu", op->GetAllInputsDesc().size()); + GELOGI("Data op inputDesc size: %zu", op->GetAllInputsDesc().size()); ge::GeTensorDesc tensor = op->GetInputDesc(0); string data_op_name = op->GetName(); - GELOGI("Data op name is: %s", data_op_name.c_str()); + GELOGI("Data op name: %s", data_op_name.c_str()); ge::GeShape data_shape; auto iter = GetContext().input_dims.find(data_op_name); if (iter != GetContext().input_dims.end()) { @@ -279,7 +278,7 @@ graphStatus Impl::CreateInputsForIRBuild(const ge::Graph &graph, vector &options, @@ -287,7 +286,7 @@ graphStatus Impl::BuildModel(const Graph &graph, const std::map(model.data.get()), static_cast(model.length)); } + +graphStatus aclgrphGetIRVersion(int *major_version, int *minor_version, int *patch_version) { + GELOGD("Enter aclgrphGetIRVersion process!"); + GE_CHECK_NOTNULL(major_version); + GE_CHECK_NOTNULL(minor_version); + GE_CHECK_NOTNULL(patch_version); + *major_version = IR_MAJOR_VERSION; + *minor_version = IR_MINOR_VERSION; + *patch_version = IR_PATCH_VERSION; + return GRAPH_SUCCESS; +} } // namespace ge diff --git a/src/ge/offline/main.cc b/src/ge/offline/main.cc deleted file mode 100644 index 214e495a..00000000 --- a/src/ge/offline/main.cc +++ /dev/null @@ -1,1227 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "common/gflags_util.h" -#include "common/util.h" -#include "common/util/error_manager/error_manager.h" -#include "common/model_parser/graph_parser_util.h" -#include "framework/common/debug/ge_log.h" -#include "ge/ge_api.h" -#include "generator/ge_generator.h" -#include "graph/anchor.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/graph.h" -#include "graph/op_desc.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/type_utils.h" -#include "init/gelib.h" -#include "ir_build/atc_ir_common.h" -#include "omg/omg.h" -#include "omg/parser/parser_factory.h" -#include "omg/parser/parser_inner_ctx.h" -#include "parser/common/register_tbe.h" -#include "register/op_registry.h" -#include "single_op_parser.h" - -using domi::BuildMode; -using domi::GetContext; -using domi::OpRegistrationData; -using domi::OpRegistry; -using domi::Status; -using domi::SUCCESS; -using ge::GEN_OM_MODEL; -using ge::GflagsUtils; -using ge::MODEL_TO_JSON; -using ge::ONLY_PRE_CHECK; -using ge::ParseInputShape; -using ge::PBTXT_TO_JSON; -using std::map; -using std::pair; -using std::shared_ptr; -using std::string; -using std::vector; - -static bool is_dynamic_input = false; - -// 310 limited 8G size -const char *const kGraphMemoryManagerMallocMaxSize = "8*1024*1024*1024"; -const char *const kModeSupport = - "only support 0(model to framework model), " - "1(framework model to json), 3(only pre-check), 5(pbtxt to json)"; -const char *const kModelToJsonSupport 
= "only support 0(Caffe) 3(TensorFlow)"; - -DEFINE_string(model, "", "The model file."); -DEFINE_string(output, "", "The output file path&name."); -DEFINE_int32(framework, -1, "Framework type(0:Caffe; 1:MindSpore; 3:Tensorflow)."); -DEFINE_string(weight, "", "Optional; weight file. Required when framework is Caffe."); - -DEFINE_string(input_shape, "", - "Optional; shape of input data. Required when framework is caffe " - "or TensorFLow or MindSpore." - "Format: \"input_name1:n1,c1,h1,w1;input_name2:n2,c2,h2,w2\""); -DEFINE_bool(h, false, "show this help message"); -DEFINE_string(cal_conf, "", "Optional; the calibration config file."); - -DEFINE_string(insert_op_conf, "", "Optional; the config file to insert new op, for example AIPP op."); -DEFINE_string(op_name_map, "", "Optional; custom op name mapping file."); - -DEFINE_string(target, "", "Optional; mini."); - -DEFINE_string(om, "", "The model file to be converted to json."); -DEFINE_string(json, "", "The output json file path&name which is converted from a model."); -DEFINE_int32(mode, 0, - "Optional; run mode, 0(default): model => framework model; 1: " - "framework model => json; 3: only pre-check; 5: pbtxt => json."); - -#if !defined(__ANDROID__) && !defined(ANDROID) -DEFINE_int32(encrypt_mode, -1, "Optional; the encrypt flag. 0: encrypt; -1(default): not encrypt"); -DEFINE_string(encrypt_key, "", "Optional; the encrypt_key file."); -DEFINE_string(certificate, "", "Optional; the certificate file."); -DEFINE_string(hardware_key, "", "Optional; the ISV key file."); -DEFINE_string(private_key, "", "Optional; the private key file."); -#endif - -DEFINE_string(out_nodes, "", - "Optional; output nodes designated by users." - "Format: \"node_name1:0;node_name1:1;node_name2:0\""); - -DEFINE_string(precision_mode, "", - "Optional; precision mode." 
- "Support force_fp16, allow_mix_precision, allow_fp32_to_fp16, must_keep_origin_dtype."); - -DEFINE_string(input_format, "", - "Optional; input_format, format of input data, NCHW;NHWC." - "Format:\"NHWC\""); - -DEFINE_string(check_report, "check_result.json", "Optional; the pre-checking report file."); - -DEFINE_string(input_fp16_nodes, "", - "Optional; input node datatype is fp16 and format is NC1HWC0." - "Format:\"node_name1;node_name2\""); - -DEFINE_string(is_output_adjust_hw_layout, "", - "Optional; Net output node's datatype is fp16 and format is " - "NC1HWC0, or not." - "Format:\"false,true,false,true\""); - -DEFINE_string(is_input_adjust_hw_layout, "", - "Optional; Intput node's datatype is fp16 and format is " - "NC1HWC0, or not." - "Format:\"false,true,false,true\""); - -DEFINE_string(output_type, "", - "Optional; output type! " - "Support FP32,FP16,INT8,INT16,UINT16,UINT8,INT32,INT64,UINT32,UINT64,DOUBLE."); - -DEFINE_string(op_select_implmode, "", - "Optional; op select implmode! " - "Support high_precision, high_performance."); - -DEFINE_string(optypelist_for_implmode, "", - "Optional; Nodes need use implmode selected in op_select_implmode " - "Format:\"node_name1,node_name2\""); - -DEFINE_string(singleop, "", "Optional; If set, generate single op model with the given json file."); - -DEFINE_int32(disable_reuse_memory, 0, "Optional; If set to 1, disable reuse memory when generating if."); - -DEFINE_string(auto_tune_mode, "", "Optional; Set tune mode."); - -DEFINE_string(soc_version, "", "The soc version."); - -DEFINE_string(core_type, "AiCore", "Optional; If set to VectorCore, only use vector core."); - -DEFINE_string(aicore_num, "", "Optional; Set aicore num"); - -DEFINE_string(buffer_optimize, "l2_optimize", "Optional; buffer optimize"); - -DEFINE_string(fusion_switch_file, "", "Optional; Set fusion switch file path"); - -DEFINE_string(save_original_model, "", "Optional; enable output original offline model. 
false(default)"); - -DEFINE_string(dynamic_batch_size, "", - "Optional; If set, generate dynamic multi batch model. " - "Different batch sizes are split by ','." - "dynamic_batch_size, dynamic_image_size and dynamic_dims can only be set one."); - -DEFINE_string(dynamic_image_size, "", - "Optional; If set, generate dynamic multi image size model." - "Different groups of image size are split by ';'," - "while different dimensions of each group are split by ','." - "dynamic_batch_size, dynamic_image_size and dynamic_dims can only be set one."); - -DEFINE_string(dynamic_dims, "", - "Optional; If set, generate dynamic input size model. " - "Different groups of size are split by ';', while different dimensions of each group are split by ','." - "dynamic_batch_size, dynamic_image_size and dynamic_dims can only be set one."); - -DEFINE_string(enable_small_channel, "0", "Optional; If set to 1, small channel is enabled."); - -DEFINE_string(enable_compress_weight, "false", - "Optional; enable compress weight. true: enable; false(default): disable"); - -DEFINE_string(compress_weight_conf, "", "Optional; the config file to compress weight"); - -DEFINE_string(enable_single_stream, "", "Optional; enable single stream. true: enable; false(default): disable"); - -DEFINE_string(log, "null", "Optional; generate atc log. Support debug, info, warning, error, null"); - -DEFINE_string(dump_mode, "0", "Optional; generate infershape json,only support 1 , 0."); - -DEFINE_int32(op_debug_level, 0, - "Optional; configure debug level of compiler. 
0(default): close debug;" - "1: open TBE compiler, export ccec file and TBE instruction mapping file; 2: open ccec compiler"); - -class GFlagUtils { - public: - /** - * @name InitGFlag - * @brief initialize gflag - * @return void - */ - static void InitGFlag(int argc, char *argv[]) { - // -help - gflags::SetUsageMessage( - "usage: ./atc \n" - "generate offline model example:\n" - "./atc --model=./alexnet.prototxt --weight=./alexnet.caffemodel \n" - "--framework=0 --output=./domi \n" - "generate offline model for single op example:\n" - "./atc --singleop=./op_list.json --output=./op_model \n" - "arguments explain:\n" - " --model Model file\n" - " --singleop Single op definition file. atc will generate offline " - "model(s) for single op if --singleop is set.\n" - " --weight Weight file. Required when framework is Caffe\n" - " --framework Framework type(0:Caffe; 1:MindSpore; 3:Tensorflow)\n" - " --output Output file path&name(needn't suffix, will add " - ".om automatically). \n" - " If --singleop is set, this arg specifies the directory to " - "which the single op offline model will be generated\n" - " --input_shape Shape of input data. Separate multiple nodes with semicolons (;)." - "Use double quotation marks (\") to enclose each argument." - "E.g.: \"input_name1:n1,c1,h1,w1;input_name2:n2,c2,h2,w2\"\n" - " --h/help Show this help message\n" - " --log Generate atc log. Support debug, info, warning, error, null\n" - " --insert_op_conf Config file to insert new op\n" - " --op_name_map Custom op name mapping file\n" - " Note: A semicolon(;) cannot be included in each " - "path, otherwise the resolved path will not match the expected one.\n" - " --precision_mode precision mode, support force_fp16, allow_mix_precision, " - "allow_fp32_to_fp16, must_keep_origin_dtype.\n" - " --om The model file to be converted to json\n" - " --json The output json file path&name which is " - "converted from a model\n" - " --mode Run mode. 
0(default): model => framework model 1: " - "framework model => json; 3: only pre-check; 5: pbtxt => json\n" - " --dump_mode The switch of dump json with shape, to be used with mode 1.Default value is : 0." - "0 means disable, 1 means enable .\n" - " --input_format Format of input data. E.g.: \"NCHW\"\n" - " --check_report The pre-checking report file. Default value is: " - "\"check_result.json\"\n" - " --disable_reuse_memory The switch of reuse memory. Default value is : 0." - "0 means reuse memory, 1 means do not reuse memory.\n" - " --input_fp16_nodes Input node datatype is fp16. Separate multiple nodes with semicolons " - "(;)." - "Use double quotation marks (\") to enclose each argument." - "E.g.: \"node_name1;node_name2\"\n" - " --is_input_adjust_hw_layout Intput node datatype is fp16 and format is " - "NC1HWC0, used with input_fp16_nodes E.g.: \"true,true,false,true\"\n" - " --out_nodes Output nodes designated by users. Separate multiple nodes with semicolons (;)." - "Use double quotation marks (\") to enclose each argument." - "E.g.: \"node_name1:0;node_name1:1;node_name2:0\"\n" - " --is_output_adjust_hw_layout Net output node datatype is fp16 and format is " - "NC1HWC0, used with out_nodes. E.g.: \"true,true,false,true\"\n" - " --output_type Set net output type. Support FP32, FP16, UINT8." - "E.g.: FP16, means all out nodes set datatype FP16." - "\"node_name1:0:FP16;node_name2:1:FP32\", means multiple out nodes set corresponding datatype.\n" - " --op_select_implmode Set op select implmode. Support high_precision, high_performance." - "default: high_performance\n" - "disable\n" - " --optypelist_for_implmode Appoint which op to use op_select_implmode, used with op_select_implmode ." - "Separate multiple nodes with commas (,). Use double quotation marks (\") to enclose each argument." - "E.g.: \"node_name1,node_name2\"\n" - " --soc_version The soc version.\n" - " --core_type Set core type AiCore or VectorCore. VectorCore: use vector core. 
" - "Default value is: AiCore\n" - " --enable_compress_weight Enable compress weight. true: enable; false(default): disable\n" - " --compress_weight_conf Config file to compress weight\n" - " --aicore_num Set aicore num\n" - " --buffer_optimize Set buffer optimize. default enabled, set \"off_optimize\" to close \n" - " --enable_small_channel Set enable small channel. 0(default): disable; 1: enable\n" - " --fusion_switch_file Set fusion switch file path\n" - " --save_original_model Control whether to output original model. " - "E.g.: true: output original model\"\n" - " --dynamic_batch_size Set dynamic batch size. E.g: \"batchsize1,batchsize2,batchsize3\"\n" - " --dynamic_image_size Set dynamic image size. Separate multiple nodes with semicolons (;)." - "Use double quotation marks (\") to enclose each argument." - "E.g: \"imagesize1_height,imagesize1_width;imagesize2_height,imagesize2_width\"\n" - " --auto_tune_mode Set tune mode. E.g.: \"GA,RL\", support configure multiple, spit by ,\n" - " --enable_single_stream Enable single stream. 
true: enable; false(default): disable\n"); - - gflags::ParseCommandLineNonHelpFlags(&argc, &argv, true); - // Using gflags to analyze input parameters - GflagsUtils::ChangeHelpFlags(FLAGS_h); - gflags::HandleCommandLineHelpFlags(); - } - - static Status CheckDumpInfershapeJsonFlags() { - Status ret = CheckFrameWorkValid(FLAGS_framework, FLAGS_weight); - GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, "check custom aicpu run so failed!"); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_weight != "" && !ge::CheckInputPathValid(FLAGS_weight, "--weight"), - return domi::FAILED, "Input parameter[--weight]'s value[%s] is invalid!", - FLAGS_weight.c_str()); - return domi::SUCCESS; - } - - static Status CheckFlags() { - // No model file information passed in - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_model == "", - ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"model"}); - return domi::PARAM_INVALID, "Input parameter[--model]'s value is empty!"); - // check param disable_reuse_memory - GE_CHK_BOOL_EXEC(ge::CheckDisableReuseMemoryParamValid(to_string(FLAGS_disable_reuse_memory)) == ge::SUCCESS, - return ge::FAILED, "check disable_reuse_memory failed!"); - - // check optypelist_for_implmode and op_select_implmode - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - ge::CheckImplmodeParamValid(FLAGS_optypelist_for_implmode, FLAGS_op_select_implmode) != ge::SUCCESS, - return ge::FAILED, "check optypelist_for_implmode and op_select_implmode failed!"); - // No output file information passed in - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_mode == GEN_OM_MODEL && FLAGS_output == "", - ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"output"}); - return domi::PARAM_INVALID, "Input parameter[--output]'s value is empty!"); - - Status ret = CheckFrameWorkValid(FLAGS_framework, FLAGS_weight); - GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, "CheckFrameWorkValid failed"); - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - 
ge::CheckDynamicInputParamValid(FLAGS_dynamic_batch_size, FLAGS_dynamic_image_size, FLAGS_dynamic_dims, - FLAGS_input_shape, FLAGS_input_format, is_dynamic_input) != ge::SUCCESS, - return ge::FAILED, "check dynamic size(batch size, image size or dims) failed!"); - -#if !defined(__ANDROID__) && !defined(ANDROID) - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!CheckEncryptModeValid(FLAGS_encrypt_mode), return domi::FAILED, - "encrypt_mode %d not valid!!", FLAGS_encrypt_mode); - - if (FLAGS_encrypt_mode == 0) { // Encryption mode - GELOGI("domi will run with encrypt!"); - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_encrypt_key), return domi::FAILED, - "encrypt_key file not found!!"); - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_certificate), return domi::FAILED, - "certificate file not found!!"); - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_hardware_key), return domi::FAILED, - "hardware_key file not found!!"); - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_private_key), return domi::FAILED, - "private_key file not found!!"); - } else { // No encryption - GELOGI("domi will run without encrypt!"); - } -#endif - - /** - * Check the validity of the I / O file path - */ - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_model, "--model"), return domi::FAILED, - "model file %s not found!!", FLAGS_model.c_str()); - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_weight != "" && !ge::CheckInputPathValid(FLAGS_weight, "--weight"), - return domi::FAILED, "weight file %s not found!!", FLAGS_weight.c_str()); - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_cal_conf != "" && !ge::CheckInputPathValid(FLAGS_cal_conf, "--cal_conf"), - return domi::FAILED, "calibration config file %s not found!!", - FLAGS_cal_conf.c_str()); - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - FLAGS_op_name_map != "" && !ge::CheckInputPathValid(FLAGS_op_name_map, "--op_name_map"), return domi::FAILED, - "op config file %s not found!!", 
FLAGS_op_name_map.c_str()); - - GE_CHK_BOOL_EXEC(ge::CheckInsertOpConfParamValid(std::string(FLAGS_insert_op_conf)) == ge::SUCCESS, - return ge::FAILED, "check insert op conf failed!"); - - GE_CHK_BOOL_EXEC( - ge::CheckCompressWeightParamValid(FLAGS_enable_compress_weight, FLAGS_compress_weight_conf) == ge::SUCCESS, - return ge::FAILED, "check compress weight failed!"); - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckOutputPathValid(FLAGS_check_report, "--check_report"), return domi::FAILED, - "check_report file %s not found!!", FLAGS_check_report.c_str()); - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_mode == GEN_OM_MODEL && (!ge::CheckOutputPathValid(FLAGS_output, "--output") || - !CheckPathWithName(FLAGS_output)), - return domi::FAILED, "output path %s is not valid!!", FLAGS_output.c_str()); - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - FLAGS_save_original_model != "" && FLAGS_save_original_model != "true" && FLAGS_save_original_model != "false", - ErrorManager::GetInstance().ATCReportErrMessage("E10005", {"parameter", "value"}, - {"save_original_model", FLAGS_save_original_model}); - return domi::FAILED, "Input parameter[--save_original_model]'s value[%s] must be true or false.", - FLAGS_save_original_model.c_str()); - GE_CHK_BOOL_EXEC(ge::CheckBufferOptimizeParamValid(FLAGS_buffer_optimize) == ge::SUCCESS, return ge::FAILED, - "check output type failed!"); - - GE_CHK_BOOL_EXEC(ge::CheckEnableSingleStreamParamValid(std::string(FLAGS_enable_single_stream)) == ge::SUCCESS, - return ge::FAILED, "check enable single stream failed!"); - - return domi::SUCCESS; - } - - /** - * Verifying the parameters of converting model to JSON - * 1. Fmk_model - * 2. 
out_json - **/ - static Status CheckConverJsonParamFlags() { - // No model path passed in - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_om == "", - ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"om"}); - return domi::PARAM_INVALID, "Input parameter[--om]'s value is empty!!"); - - // JSON path not passed in - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_json == "", - ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"json"}); - return domi::PARAM_INVALID, "Input parameter[--json]'s value is empty!!"); - // Check if the model path is valid - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_om, "--om"), return domi::PARAM_INVALID, - "model file path is invalid: %s.", FLAGS_om.c_str()); - // Check whether the JSON path is valid - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckOutputPathValid(FLAGS_json, "--json"), return domi::PARAM_INVALID, - "json file path is invalid: %s.", FLAGS_json.c_str()); - - return domi::SUCCESS; - } - - /** - * Check command line parameters for explicit settings - * true: Explicit setup - * false: Not set up - * */ - static bool CheckFlagSet(string flag) { - gflags::CommandLineFlagInfo info; - return !(gflags::GetCommandLineFlagInfo(flag.c_str(), &info) && info.is_default); - } - - private: - static bool CheckEncryptModeValid(const int encrypt_mode) { -#if !defined(__ANDROID__) && !defined(ANDROID) - if (encrypt_mode != 0 && encrypt_mode != -1) { - DOMI_LOGE("encrypt mode must be 0 or -1"); - return false; - } -#else - if (encrypt_mode != -1) { - DOMI_LOGE("encrypt mode must be -1"); - return false; - } -#endif - - return true; - } - - static Status CheckFrameWorkValid(int framework, const std::string weight_file) { - if (framework != (int32_t)domi::CAFFE && framework != (int32_t)domi::TENSORFLOW && - framework != (int32_t)domi::MINDSPORE && framework != (int32_t)domi::ONNX) { - // No framework information was passed in or the entered framework is illegal - 
ErrorManager::GetInstance().ATCReportErrMessage("E10007", {"parameter", "support"}, - {"framework", "0(Caffe) or 1(MindSpore) or 3(TensorFlow)"}); - DOMI_LOGE( - "Input parameter[--framework] is mandatory and it's value must be: " - "0(Caffe) or 1(MindSpore) or 3(TensorFlow)."); - return domi::PARAM_INVALID; - } - - if ((framework == (int32_t)domi::CAFFE) && (weight_file == "")) { - ErrorManager::GetInstance().ATCReportErrMessage("E10008", {"parameter"}, {"weight"}); - DOMI_LOGE("Input parameter[--weight]'s value is empty when framework is 0(CAFFE)!"); - return domi::PARAM_INVALID; - } - - if ((framework == (int32_t)domi::TENSORFLOW) && (weight_file != "")) { - GELOGW("Parameter weight is ignored for TensorFlow."); - } - - if ((framework == (int32_t)domi::ONNX) && (weight_file != "")) { - GELOGW("Parameter weight is ignored for Onnx."); - } - return domi::SUCCESS; - } - - static bool CheckPathWithName(const std::string &fileName) { - // Determine file path length - if (fileName.size() > static_cast(PATH_MAX)) { - ErrorManager::GetInstance().ATCReportErrMessage("E10021", {"parameter", "size"}, - {"output", std::to_string(PATH_MAX)}); - GELOGE(ge::FAILED, "Input parameter[--output]'s path is too long, it must be less than %d", PATH_MAX); - return false; - } - - // Find the last separator - int slashPosition = fileName.size() - 1; - for (; slashPosition >= 0; slashPosition--) { - if (fileName[slashPosition] == '\\' || fileName[slashPosition] == '/') { - break; - } - } - - // Failure if no filename follows the path - if (slashPosition == static_cast(fileName.size() - 1)) { - ErrorManager::GetInstance().ATCReportErrMessage("E10022", {"parameter", "filename"}, {"output", fileName}); - DOMI_LOGE("Input parameter[--output]'s path[%s] not include file name", fileName.c_str()); - return false; - } - - return true; - } -}; - -void SetDynamicInputSizeOptions() { - if (!FLAGS_dynamic_batch_size.empty()) { - domi::GetContext().dynamic_batch_size = FLAGS_dynamic_batch_size; - } - 
if (!FLAGS_dynamic_image_size.empty()) { - domi::GetContext().dynamic_image_size = FLAGS_dynamic_image_size; - } - if (!FLAGS_dynamic_dims.empty()) { - domi::GetContext().dynamic_dims = FLAGS_dynamic_dims; - } -} - -static bool CheckInputFormat() { - if (FLAGS_input_format.empty()) { - // Set default format - if (FLAGS_framework == static_cast(domi::TENSORFLOW)) { - FLAGS_input_format = "NHWC"; - } else { - FLAGS_input_format = "NCHW"; - } - return true; - } else if ((FLAGS_framework == static_cast(domi::CAFFE))) { // caffe - if (ge::caffe_support_input_format.find(FLAGS_input_format) != ge::caffe_support_input_format.end()) { - return true; - } - // only support NCHW ND - ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, - {"--input_format", FLAGS_input_format, ge::kCaffeFormatSupport}); - GELOGE(ge::FAILED, "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), - ge::kCaffeFormatSupport); - return false; - } else if ((FLAGS_framework == static_cast(domi::TENSORFLOW))) { // tf - if (ge::tf_support_input_format.find(FLAGS_input_format) != ge::tf_support_input_format.end()) { - return true; - } - // only support NCHW NHWC ND NCDHW NDHWC - ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, - {"--input_format", FLAGS_input_format, ge::kTFFormatSupport}); - GELOGE(ge::FAILED, "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), ge::kTFFormatSupport); - return false; - } else if (FLAGS_framework == static_cast(domi::ONNX)) { - if (ge::onnx_support_input_format.find(FLAGS_input_format) != ge::onnx_support_input_format.end()) { - return true; - } - // only support NCHW ND - ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, - {"--input_format", FLAGS_input_format, ge::kONNXFormatSupport}); - GELOGE(ge::FAILED, "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), ge::kONNXFormatSupport); - 
return false; - } - return true; -} - -#if !defined(__ANDROID__) && !defined(ANDROID) -static void GetCustomOpPath(std::string &customop_path) { - GELOGI("Enter get custom op path schedule"); - std::string fmk_type = ge::TypeUtils::FmkTypeToSerialString(static_cast(FLAGS_framework)); - GELOGI("Framework type is %s.", fmk_type.c_str()); - - const char *path_env = std::getenv("ASCEND_OPP_PATH"); - if (path_env != nullptr) { - std::string path = path_env; - customop_path = (path + "/framework/custom" + "/:") + (path + "/framework/built-in/" + fmk_type); - GELOGI("Get custom so path from env : %s", path_env); - return; - } - std::string path_base = ge::GELib::GetPath(); - GELOGI("path_base is %s", path_base.c_str()); - path_base = path_base.substr(0, path_base.rfind('/')); - path_base = path_base.substr(0, path_base.rfind('/') + 1); - customop_path = (path_base + "ops/framework/custom" + "/:") + (path_base + "ops/framework/built-in/" + fmk_type); - return; -} - -void GetPluginSoFileList(const string &path, vector &fileList, string &caffe_parser_path) { - // Support to split multiple so directories by ":" - GELOGI("path is %s", path.c_str()); - vector v_path = ge::StringUtils::Split(path, ':'); - for (size_t i = 0; i < v_path.size(); ++i) { - ge::FindParserSo(v_path[i], fileList, caffe_parser_path); - GELOGI("CustomOpLib full name = %s", v_path[i].c_str()); - } -} - -void LoadModelParserLib(std::string caffe_parser_path) { - if (FLAGS_framework == static_cast(domi::TENSORFLOW)) { - void *tf_handle = dlopen("libfmk_parser.so", RTLD_NOW | RTLD_GLOBAL); - if (tf_handle == nullptr) { - GELOGW("dlopen fmk library [libfmk_parser.so] failed."); - return; - } - GELOGI("plugin load libfmk_parser.so success."); - } else if (FLAGS_framework == static_cast(domi::CAFFE)) { - // What we are dealing with here is that the user modifies the caffe.proto scenario. - // If no lib_Caffe_Parser.so is found under the plugin path, use the default lib_Caffe_Parser.so path. 
- caffe_parser_path = caffe_parser_path.empty() ? "lib_caffe_parser.so" : caffe_parser_path; - - void *handle = dlopen(caffe_parser_path.c_str(), RTLD_NOW | RTLD_GLOBAL); - if (handle == nullptr) { - GELOGW("dlopen failed, plugin name:%s. Message(%s).", caffe_parser_path.c_str(), dlerror()); - return; - } - GELOGI("plugin load %s success.", caffe_parser_path.c_str()); - // According to the dependency, the Caffe parsing module of the framework is loaded here( libfmk_parser.so). - // (depend on the lib_caffe_parser.so) - void *fmk_handle = dlopen("libfmk_parser.so", RTLD_NOW | RTLD_GLOBAL); - if (fmk_handle == nullptr) { - GELOGW("dlopen fmk library [libfmk_parser.so] failed."); - if (dlclose(handle) != 0) { - GELOGW("dlclose lib_caffe_parser.so failed."); - } - return; - } - GELOGI("plugin load libfmk_parser.so success."); - } else if (FLAGS_framework == static_cast(domi::ONNX)) { - void *handle = dlopen("libfmk_onnx_parser.so", RTLD_NOW | RTLD_GLOBAL); - if (handle == nullptr) { - GELOGW("dlopen fmk library [libfmk_onnx_parser.so] failed."); - return; - } - GELOGI("plugin load libfmk_onnx_parser.so success."); - } else { - GELOGW("Framework:%s is not support.", - ge::TypeUtils::FmkTypeToSerialString(static_cast(FLAGS_framework)).c_str()); - return; - } - return; -} - -void LoadCustomOpLib(bool need_load_ops_plugin) { - std::string plugin_path; - GetCustomOpPath(plugin_path); - - vector fileList; - string caffe_parser_path = ""; - - // whether there are files in the plugin so path - GetPluginSoFileList(plugin_path, fileList, caffe_parser_path); - - // no file - if (fileList.empty() && caffe_parser_path.empty()) { - GELOGW("can not find any plugin file in plugin_path: %s", plugin_path.c_str()); - } - - LoadModelParserLib(caffe_parser_path); - if (!need_load_ops_plugin) { - GELOGI("No need to load ops plugin so."); - return; - } - OpRegistry::Instance()->registrationDatas.clear(); - // load other so files except lib_caffe_parser.so in the plugin so path - for (auto 
elem : fileList) { - ge::StringUtils::Trim(elem); - - void *handle = dlopen(elem.c_str(), RTLD_NOW | RTLD_GLOBAL); - if (handle == nullptr) { - GELOGW("dlopen failed, plugin name:%s. Message(%s).", elem.c_str(), dlerror()); - } else { - GELOGI("plugin load %s success.", elem.c_str()); - } - } - - std::vector registrationDatas = OpRegistry::Instance()->registrationDatas; - for (OpRegistrationData reg_data : registrationDatas) { - if (reg_data.GetFrameworkType() == static_cast(FLAGS_framework)) { - (void)ge::OpRegistrationTbe::Instance()->Finalize(reg_data); - (void)OpRegistry::Instance()->Register(reg_data); - } - } -} - -void SaveCustomCaffeProtoPath() { - GELOGI("Enter save custom caffe proto path."); - - std::string path_base = ge::GELib::GetPath(); - GELOGI("path_base is %s", path_base.c_str()); - path_base = path_base.substr(0, path_base.rfind('/')); - path_base = path_base.substr(0, path_base.rfind('/') + 1); - ge::GetParserContext().caffe_proto_path = path_base + "include/proto/"; - - string customop_path; - const char *path_env = std::getenv("ASCEND_OPP_PATH"); - if (path_env != nullptr) { - std::string path = path_env; - customop_path = path + "/framework/custom/caffe/"; - GELOGI("Get custom proto path from env : %s", path_env); - ge::GetParserContext().custom_proto_path = customop_path; - return; - } - customop_path = path_base + "ops/framework/custom/caffe/"; - ge::GetParserContext().custom_proto_path = customop_path; - return; -} - -#endif - -Status CreateInputsForInference(const ge::Graph &graph, vector &inputs) { - auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); - GE_CHECK_NOTNULL(compute_graph); - for (ge::NodePtr &input_node : compute_graph->GetAllNodes()) { - GE_CHECK_NOTNULL(input_node); - ge::OpDescPtr op = input_node->GetOpDesc(); - GE_CHECK_NOTNULL(op); - if (op->GetType() == ge::DATA) { - GELOGI("Data op inputDesc size is: %zu", op->GetAllInputsDesc().size()); - ge::GeTensorDesc tensor = op->GetInputDesc(0); - string data_op_name = 
op->GetName(); - GELOGI("Data op name is: %s", data_op_name.c_str()); - ge::GeShape data_shape; - auto iter = GetContext().input_dims.find(data_op_name); - if (iter != GetContext().input_dims.end()) { - data_shape = ge::GeShape(iter->second); - GELOGI("Data op get shape from Context."); - } else { - data_shape = tensor.GetShape(); - GELOGI("Data op get shape from InputDesc in geir graph."); - } - - ge::DataType data_type = tensor.GetDataType(); - string data_type_str = ge::TypeUtils::DataTypeToSerialString(data_type); - GELOGI("Data op get data type:%s from InputDesc in geir graph.", data_type_str.c_str()); - - ge::GeTensor input_tensor; - ge::GeTensorDesc desc(data_shape, ge::Format(GetContext().format), data_type); - input_tensor.SetTensorDesc(desc); - inputs.push_back(input_tensor); - } - } - GELOGI("Build ME model, inputs size is: %zu", inputs.size()); - return ge::SUCCESS; -} - -domi::Status GenerateInfershapeJson() { - if (!CheckInputFormat()) { - GELOGE(ge::FAILED, "Check input_format failed"); - return domi::FAILED; - } - Status ret = GFlagUtils::CheckDumpInfershapeJsonFlags(); - GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, "Check flags failed!"); - - ge::GeGenerator ge_generator; - std::map options; - ge::Status geRet = ge_generator.Initialize(options); - if (geRet != ge::SUCCESS) { - DOMI_LOGE("GeGenerator initialize failed!"); - return domi::FAILED; - } - - ge::Graph graph; - std::map atc_params; - atc_params.insert(std::pair("input_format", FLAGS_input_format)); - ret = ParseGraph(graph, atc_params, FLAGS_om.c_str(), FLAGS_weight.c_str(), (domi::FrameworkType)FLAGS_framework, "", - FLAGS_target.c_str(), (ge::RunMode)FLAGS_mode, false); - if (ret != ge::SUCCESS) { - DOMI_LOGE("ATC Parse graph domi::FAILED"); - (void)ge_generator.Finalize(); - return domi::FAILED; - } - - geRet = ge_generator.GenerateInfershapeGraph(graph); - if (geRet != ge::SUCCESS) { - DOMI_LOGE("ATC GenerateInfershapeJson failed"); - (void)ge_generator.Finalize(); - 
return domi::FAILED; - } - if (DumpInfershapeJson(graph, FLAGS_json.c_str()) != SUCCESS) { - DOMI_LOGE("ATC DumpInfershapeJson failed"); - (void)ge_generator.Finalize(); - return domi::FAILED; - } - (void)ge_generator.Finalize(); - return ge::SUCCESS; -} - -static Status ConvertModelToJson(int fwk_type, const string &model_file, const string &json_file) { - Status ret = domi::SUCCESS; - if (fwk_type == -1) { - ret = ge::ConvertOmModelToJson(model_file.c_str(), json_file.c_str()); - return ret; - } - - if ((fwk_type != domi::TENSORFLOW) && (fwk_type != domi::CAFFE) && (fwk_type != domi::ONNX)) { - ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, - {"--framework", std::to_string(fwk_type), kModelToJsonSupport}); - GELOGE(ge::FAILED, "Invalid value for --framework[%d], %s.", fwk_type, kModelToJsonSupport); - return ge::FAILED; - } - - if (FLAGS_dump_mode == "0") { - // Caffe or tf model to json depend on lib_caffe_parser.so or libfmk_parser.so. - LoadCustomOpLib(false); - ret = ge::ConvertFwkModelToJson((domi::FrameworkType)fwk_type, model_file.c_str(), json_file.c_str()); - return ret; - } else if (FLAGS_dump_mode == "1") { - // Caffe or tf model to json depend on lib_caffe_parser.so or libfmk_parser.so and ops plugin so. 
- LoadCustomOpLib(true); - ret = GenerateInfershapeJson(); - return ret; - } else { - ErrorManager::GetInstance().ATCReportErrMessage("E10006", {"parameter"}, {"dump_mode"}); - GELOGE(ge::FAILED, "Input parameter[--dump_mode]'s value must be 1 or 0."); - return ge::FAILED; - } -} - -domi::Status GenerateModel(std::map &options, std::string output) { - ge::GeGenerator ge_generator; - ge::Status geRet = ge::SUCCESS; - std::shared_ptr instance_ptr = ge::GELib::GetInstance(); - if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { - geRet = ge::GELib::Initialize(options); - if (geRet != ge::SUCCESS) { - DOMI_LOGE("GE initialize failed!"); - return domi::FAILED; - } - } - geRet = ge_generator.Initialize(options); - if (geRet != ge::SUCCESS) { - DOMI_LOGE("GeGenerator initialize failed!"); - (void)ge::GELib::GetInstance()->Finalize(); - return domi::FAILED; - } - - ge::Graph graph; - std::vector inputs; - if (FLAGS_framework == domi::MINDSPORE) { - // load model from file - ge::Model load_model = ge::Model("loadmodel", "version2"); - auto ret1 = load_model.LoadFromFile(FLAGS_model); - if (ret1 != ge::GRAPH_SUCCESS) { - ErrorManager::GetInstance().ATCReportErrMessage("E10041", {"parameter"}, {FLAGS_model}); - DOMI_LOGE( - "Load model from %s failed, please check model file or " - "input parameter[--framework] is correct", - FLAGS_model.c_str()); - (void)ge_generator.Finalize(); - (void)ge::GELib::GetInstance()->Finalize(); - return domi::FAILED; - } - - graph = load_model.GetGraph(); - - GE_CHK_STATUS_EXEC(ge::InitDomiOmgContext(FLAGS_input_shape, FLAGS_input_format, "", is_dynamic_input), - GELOGE(ge::FAILED, "ATC Generate call InitDomiOmgContext ret fail"); - (void)ge_generator.Finalize(); (void)ge::GELib::GetInstance()->Finalize(); return domi::FAILED); - - Status ret = CreateInputsForInference(graph, inputs); - if (ret != ge::SUCCESS) { - GELOGE(ge::FAILED, "create inputs for inference failed."); - (void)ge_generator.Finalize(); - 
(void)ge::GELib::GetInstance()->Finalize(); - return domi::FAILED; - } - - if (SetOutputNodeInfo(graph, "", "") != domi::SUCCESS) { - GELOGE(ge::FAILED, "Set output node info fail."); - (void)ge_generator.Finalize(); - (void)ge::GELib::GetInstance()->Finalize(); - return domi::FAILED; - } - } else { - std::map atc_params; - atc_params.insert(std::pair("input_shape", FLAGS_input_shape)); - atc_params.insert(std::pair("out_nodes", FLAGS_out_nodes)); - atc_params.insert(std::pair("input_format", FLAGS_input_format)); - atc_params.insert(std::pair("check_report", FLAGS_check_report)); - atc_params.insert(std::pair("input_fp16_nodes", FLAGS_input_fp16_nodes)); - atc_params.insert(std::pair("is_input_adjust_hw_layout", FLAGS_is_input_adjust_hw_layout)); - atc_params.insert(std::pair("is_output_adjust_hw_layout", FLAGS_is_output_adjust_hw_layout)); - atc_params.insert(std::pair("compress_weight_conf", FLAGS_compress_weight_conf)); - atc_params.insert(std::pair(string(ge::OUTPUT_DATATYPE), FLAGS_output_type)); - atc_params.insert(std::pair("output", output)); - - Status ret = - ParseGraph(graph, atc_params, FLAGS_model.c_str(), FLAGS_weight.c_str(), (domi::FrameworkType)FLAGS_framework, - FLAGS_op_name_map.c_str(), FLAGS_target.c_str(), (ge::RunMode)FLAGS_mode, is_dynamic_input); - - // in ONLY_PRE_CHECK mode, pre-checking report has already saved in ParseGraph - if (FLAGS_mode == ge::ONLY_PRE_CHECK) { - (void)ge_generator.Finalize(); - (void)ge::GELib::GetInstance()->Finalize(); - if (ret != ge::SUCCESS) { - DOMI_LOGE("ATC precheck fail."); - return domi::FAILED; - } - return domi::SUCCESS; - } - - if (ret != ge::SUCCESS) { - DOMI_LOGE("ATC Parse graph domi::FAILED"); - DOMI_LOGE("ATC Generate execute failed"); // Duplicate log. 
(for test case - (void)ge_generator.Finalize(); - (void)ge::GELib::GetInstance()->Finalize(); - return domi::FAILED; - } - if (ge::SetOutputNodeInfo(graph, FLAGS_output_type, "") != domi::SUCCESS) { - DOMI_LOGE("Set output node info fail."); - (void)ge_generator.Finalize(); - (void)ge::GELib::GetInstance()->Finalize(); - return domi::FAILED; - } - } - - geRet = ge_generator.GenerateOfflineModel(graph, output, inputs); - if (geRet != ge::SUCCESS) { - DOMI_LOGE("GE GenerateOfflineModel execute failed"); - DOMI_LOGE("ATC Generate execute failed"); // Duplicate log. (for test case - // checking error log) - (void)ge_generator.Finalize(); - (void)ge::GELib::GetInstance()->Finalize(); - return domi::FAILED; - } - (void)ge_generator.Finalize(); - (void)ge::GELib::GetInstance()->Finalize(); - return ge::SUCCESS; -} - -static void SetEnvForSingleOp(std::map &options) { - string flag_on = "1"; - string flag_off = "0"; - options.emplace(ge::GE_FE_FLAG, flag_on); - options.emplace(ge::STREAM_NUM, "1"); // single op only use one stream - options.emplace(ge::RUN_FLAG, flag_off); - options.emplace(ge::OPTION_GRAPH_RUN_MODE, flag_off); - options.emplace(ge::SINGLE_OP_FLAG, flag_on); - options.emplace(ge::PRECISION_MODE, FLAGS_precision_mode); - options.emplace(ge::SOC_VERSION, FLAGS_soc_version); - options.emplace(ge::CORE_TYPE, FLAGS_core_type); - options.emplace(ge::AICORE_NUM, FLAGS_aicore_num); - options.emplace(ge::OP_SELECT_IMPL_MODE, FLAGS_op_select_implmode); - options.emplace(ge::OPTYPELIST_FOR_IMPLMODE, FLAGS_optypelist_for_implmode); - options.emplace(ge::AUTO_TUNE_MODE, FLAGS_auto_tune_mode); - options.emplace(ge::GRAPH_MEMORY_MAX_SIZE, kGraphMemoryManagerMallocMaxSize); - options.emplace(ge::OP_DEBUG_LEVEL, to_string(FLAGS_op_debug_level)); -} - -domi::Status GenerateSingleOp(const std::string &json_file_path) { - if (!FLAGS_output.empty() && !ge::CheckOutputPathValid(FLAGS_output, "--output")) { - DOMI_LOGE("output path %s is not valid!", FLAGS_output.c_str()); - 
return domi::FAILED; - } - // check optypelist_for_implmode and op_select_implmode - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - ge::CheckImplmodeParamValid(FLAGS_optypelist_for_implmode, FLAGS_op_select_implmode) != ge::SUCCESS, - return ge::FAILED, "check optypelist_for_implmode and op_select_implmode failed!"); - - std::map options; - // need to be changed when ge.ini plan is done - SetEnvForSingleOp(options); - - auto ret = ge::GELib::Initialize(options); - if (ret != ge::SUCCESS) { - DOMI_LOGE("GE initialize failed!"); - return domi::FAILED; - } - - ge::GeGenerator generator; - ret = generator.Initialize(options); - if (ret != SUCCESS) { - DOMI_LOGE("GeGenerator initialize failed!"); - (void)ge::GELib::GetInstance()->Finalize(); - return domi::FAILED; - } - - vector build_params; - if (ge::SingleOpParser::ParseSingleOpList(json_file_path, build_params) != ge::SUCCESS) { - DOMI_LOGE("parse single op json file failed"); - (void)generator.Finalize(); - (void)ge::GELib::GetInstance()->Finalize(); - return domi::FAILED; - } - - int index = 0; - for (auto ¶m : build_params) { - string output_path; - if (!FLAGS_output.empty()) { - output_path = FLAGS_output + "/"; - } - output_path += param.file_name; - ret = generator.BuildSingleOpModel(param.op_desc, param.inputs, param.outputs, output_path); - if (ret != SUCCESS) { - DOMI_LOGE("Compile op failed. ge ret = %u, op index = %d", ret, index); - ret = domi::FAILED; - break; - } - GELOGI("Compile op success. op index = %d, output = %s", index, output_path.c_str()); - index += 1; - } - - (void)generator.Finalize(); - (void)ge::GELib::GetInstance()->Finalize(); - return ret; -} - -domi::Status GenerateOmModel() { - if (!CheckInputFormat()) { - GELOGE(ge::FAILED, "Check input_format failed"); - return domi::FAILED; - } - Status ret = GFlagUtils::CheckFlags(); - GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, - "Check flags failed! 
Please check whether some atc params that include semicolons[;] use double " - "quotation marks (\") to enclose each argument such as out_nodes, input_shape, dynamic_image_size"); -#if !defined(__ANDROID__) && !defined(ANDROID) - // Load custom operator Library - LoadCustomOpLib(true); - - SaveCustomCaffeProtoPath(); - - ret = ge::CheckCustomAiCpuOpLib(); - - GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, "check custom aicpu run so failed!"); -#endif - - const int f_stream_num = 1; - std::map options; - options.insert(std::pair(string(ge::FRAMEWORK_TYPE), to_string(FLAGS_framework))); - options.insert(std::pair(string(ge::STREAM_NUM), to_string(f_stream_num))); - options.insert(std::pair(string(ge::CALIBRATION_CONF_FILE), FLAGS_cal_conf)); - options.insert(std::pair(string(ge::ENCRYPT_MODE), to_string(FLAGS_encrypt_mode))); - options.insert(std::pair(string(ge::EK_FILE), FLAGS_encrypt_key)); - options.insert(std::pair(string(ge::CERT_FILE), FLAGS_certificate)); - options.insert(std::pair(string(ge::HW_KEY_FILE), FLAGS_hardware_key)); - options.insert(std::pair(string(ge::PRIVATE_KEY_FILE), FLAGS_private_key)); - options.insert(std::pair(string(ge::OUTPUT_NODE_NAME), FLAGS_out_nodes)); - options.insert(std::pair(string(ge::INSERT_OP_FILE), FLAGS_insert_op_conf)); - options.insert(std::pair(string(ge::PRECISION_MODE), FLAGS_precision_mode)); - - options.insert(std::pair(string(ge::RUN_FLAG), to_string(0))); - options.insert(std::pair(string(ge::TRAIN_FLAG), to_string(0))); - - if (!FLAGS_output_type.empty()) { - options.insert(std::pair(string(ge::OUTPUT_DATATYPE), FLAGS_output_type)); - } - - options.insert(std::pair(string(ge::OP_SELECT_IMPL_MODE), FLAGS_op_select_implmode)); - options.insert(std::pair(string(ge::OPTYPELIST_FOR_IMPLMODE), FLAGS_optypelist_for_implmode)); - - if (!FLAGS_input_fp16_nodes.empty()) { - GELOGI("FLAGS_input_fp16_nodes : %s .", FLAGS_input_fp16_nodes.c_str()); - options.insert(std::pair(ge::INPUT_FP16_NODES, 
FLAGS_input_fp16_nodes)); - } - - options.insert(std::pair(string(ge::AUTO_TUNE_MODE), FLAGS_auto_tune_mode)); - - options.insert( - std::pair(string(ge::OPTION_EXEC_DISABLE_REUSED_MEMORY), to_string(FLAGS_disable_reuse_memory))); - - options.insert(std::pair(string(ge::SOC_VERSION), FLAGS_soc_version)); - - options.insert(std::pair(string(ge::CORE_TYPE), FLAGS_core_type)); - - options.insert(std::pair(string(ge::AICORE_NUM), FLAGS_aicore_num)); - - options.insert(std::pair(string(ge::BUFFER_OPTIMIZE), FLAGS_buffer_optimize)); - - options.insert(std::pair(string(ge::ENABLE_SMALL_CHANNEL), FLAGS_enable_small_channel)); - - options.insert(std::pair(string(ge::FUSION_SWITCH_FILE), FLAGS_fusion_switch_file)); - - options.insert(std::pair(string(ge::ENABLE_COMPRESS_WEIGHT), (FLAGS_enable_compress_weight == "true") - ? ge::kEnableCompressWeightTrue - : ge::kEnableCompressWeightFalse)); - - options.insert(std::pair(string(ge::GRAPH_MEMORY_MAX_SIZE), kGraphMemoryManagerMallocMaxSize)); - - options.insert(std::pair(string(ge::ENABLE_SINGLE_STREAM), FLAGS_enable_single_stream)); - - SetDynamicInputSizeOptions(); - - if (!FLAGS_save_original_model.empty()) { - options.insert(std::pair(string(ge::SAVE_ORIGINAL_MODEL), FLAGS_save_original_model)); - options.insert(std::pair(string(ge::ORIGINAL_MODEL_FILE), FLAGS_output + "_original.om")); - } - - options.insert(std::pair(string(ge::OP_DEBUG_LEVEL), to_string(FLAGS_op_debug_level))); - - // print atc option map - ge::PrintOptionMap(options, "atc option"); - - // When the ATC module is transferred to a model, the suffix ".om" is automatically added to the model name - FLAGS_output = FLAGS_output + ".om"; - ret = GenerateModel(options, FLAGS_output); - if (ret != domi::SUCCESS) { - return domi::FAILED; - } - - return domi::SUCCESS; -} - -domi::Status ConvertModelToJson() { - Status ret = GFlagUtils::CheckConverJsonParamFlags(); - GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, "Check convert json params flags 
failed!"); - - ret = ConvertModelToJson(FLAGS_framework, FLAGS_om, FLAGS_json); - - GE_IF_BOOL_EXEC(ret != domi::SUCCESS, return domi::FAILED); - return domi::SUCCESS; -} - -bool CheckRet(domi::Status ret) { - if (ret != domi::SUCCESS) { - if (FLAGS_mode == ONLY_PRE_CHECK) { - GELOGW("ATC precheck failed."); - } else if (FLAGS_mode == GEN_OM_MODEL) { - GELOGW("ATC generate offline model failed."); - } else if (FLAGS_mode == MODEL_TO_JSON) { - GELOGW("ATC convert model to json file failed."); - } else if (FLAGS_mode == PBTXT_TO_JSON) { - GELOGW("ATC convert pbtxt to json file failed."); - } else { - return false; - } - return false; - } - - if (FLAGS_mode == ONLY_PRE_CHECK) { - GELOGI("ATC precheck success."); - } else if (FLAGS_mode == GEN_OM_MODEL) { - GELOGI("ATC generate offline model success."); - } else if (FLAGS_mode == MODEL_TO_JSON) { - GELOGI("ATC convert model to json file success."); - } else if (FLAGS_mode == PBTXT_TO_JSON) { - GELOGI("ATC convert pbtxt to json file success."); - } - return true; -} - -domi::Status ConvertPbtxtToJson() { - Status ret = GFlagUtils::CheckConverJsonParamFlags(); - if (ret != domi::SUCCESS) { - GELOGE(ge::FAILED, "Check convert json params flags failed!"); - return domi::FAILED; - } - - ret = ge::ConvertPbtxtToJson(FLAGS_om.c_str(), FLAGS_json.c_str()); - if (ret != domi::SUCCESS) { - GELOGE(ge::FAILED, "ConvertPbtxtToJson fail."); - return domi::FAILED; - } - - return domi::SUCCESS; -} - -int init(int argc, char *argv[]) { - GFlagUtils::InitGFlag(argc, argv); - // set log level - int ret = -1; - const std::set log_level = {"null", "debug", "info", "warning", "error"}; - if (log_level.count(FLAGS_log) == 0) { - std::cout << "E10010: invalid value for --log:" << FLAGS_log << ", only support debug, info, warning, error, null" - << std::endl; - return ret; - } - - ret = ge::CheckLogParamValidAndSetLogLevel(FLAGS_log); - if (ret != 0) { - return ret; - } - - std::string path_base = ge::GELib::GetPath(); - ret = 
ErrorManager::GetInstance().Init(path_base); - if (ret != 0) { - DOMI_LOGE("ErrorManager init fail !"); - return ret; - } - - return 0; -} - -int main(int argc, char *argv[]) { - Status ret = domi::SUCCESS; - std::cout << "ATC start working now, please wait for a moment." << std::endl; - try { - // Initialize - if (init(argc, argv) != 0) { - std::cout << "ATC run failed, Please check the detail log, Try \'atc --help\' for more information" << std::endl; - return -1; - } - - do { - if (!FLAGS_singleop.empty()) { - ret = GenerateSingleOp(FLAGS_singleop); - break; - } - - // default mode(mode:0), Open source model to model - if (GEN_OM_MODEL == FLAGS_mode || ONLY_PRE_CHECK == FLAGS_mode) { - GE_IF_BOOL_EXEC(GenerateOmModel() != domi::SUCCESS, ret = domi::FAILED; break); - } else if (MODEL_TO_JSON == FLAGS_mode) { // Mode 1, transfer model to JSON - GE_CHK_BOOL_EXEC(ConvertModelToJson() == domi::SUCCESS, ret = domi::FAILED; - break, "ATC ConvertJson execute failed!!"); - } else if (FLAGS_mode == ge::RunMode::PBTXT_TO_JSON) { - GE_CHK_BOOL_EXEC(ConvertPbtxtToJson() == domi::SUCCESS, ret = domi::FAILED; - break, "ATC convert pbtxt to json execute failed!!"); - } else { - ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, - {"--mode", std::to_string(FLAGS_mode), kModeSupport}); - GELOGE(ge::PARAM_INVALID, "Invalid value for --mode[%d], %s.", FLAGS_mode, kModeSupport); - ret = domi::FAILED; - break; - } - } while (0); - } catch (std::bad_alloc) { - ret = domi::FAILED; - DOMI_LOGE("ATC run failed, bad memory allocation occur !"); - std::cout << "ATC run failed, bad memory allocation occur !" << std::endl; - } catch (...) { - ret = domi::FAILED; - DOMI_LOGE("ATC run failed, some exceptions occur !"); - std::cout << "ATC run failed, some exceptions occur !" 
<< std::endl; - } - - if (!CheckRet(ret)) { - std::cout << "ATC run failed, Please check the detail log, Try \'atc --help\' for more information" << std::endl; - int result = ErrorManager::GetInstance().OutputErrMessage(STDOUT_FILENO); - if (result != 0) { - DOMI_LOGE("ErrorManager outputErrMessage fail !"); - } - return ret; - } else { - std::cout << "ATC run success, welcome to the next use." << std::endl; - (void)ErrorManager::GetInstance().OutputMessage(STDOUT_FILENO); - return 0; - } -} /*lint +e530*/ diff --git a/src/ge/offline/module.mk b/src/ge/offline/module.mk deleted file mode 100644 index a347362a..00000000 --- a/src/ge/offline/module.mk +++ /dev/null @@ -1,53 +0,0 @@ - -LOCAL_PATH := $(call my-dir) - -include $(CLEAR_VARS) - -LOCAL_MODULE := atc - -LOCAL_CFLAGS += -Werror -LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -O2 - -LOCAL_SRC_FILES := \ - main.cc \ - single_op_parser.cc \ - ../session/omg.cc \ - ../ir_build/atc_ir_common.cc \ - -LOCAL_C_INCLUDES := \ - $(LOCAL_PATH)/../ ./ \ - $(TOPDIR)inc \ - $(TOPDIR)inc/external \ - $(TOPDIR)inc/external/graph \ - $(TOPDIR)inc/framework \ - $(TOPDIR)inc/framework/domi \ - $(TOPDIR)libc_sec/include \ - $(TOPDIR)inc/common/util \ - third_party/json/include \ - third_party/gflags/include \ - third_party/protobuf/include \ - proto/om.proto \ - proto/ge_ir.proto \ - proto/task.proto \ - proto/insert_op.proto \ - -LOCAL_SHARED_LIBRARIES := \ - libc_sec \ - libge_common \ - libprotobuf \ - libslog \ - libgraph \ - libregister \ - liberror_manager \ - libge_compiler \ - libruntime_compile \ - libparser_common \ - libfmk_parser \ - liberror_manager \ - -LOCAL_STATIC_LIBRARIES := libgflags - -LOCAL_LDFLAGS := -lrt -ldl - -include $(BUILD_HOST_EXECUTABLE) - diff --git a/src/ge/offline/single_op_parser.cc b/src/ge/offline/single_op_parser.cc deleted file mode 100644 index 54b6df69..00000000 --- a/src/ge/offline/single_op_parser.cc +++ /dev/null @@ -1,426 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies 
Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "single_op_parser.h" - -#include -#include -#include -#include - -#include - -#include "framework/common/debug/ge_log.h" -#include "common/util/error_manager/error_manager.h" -#include "common/ge_inner_error_codes.h" -#include "framework/common/util.h" -#include "graph/utils/tensor_utils.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/operator_factory_impl.h" - -using Json = nlohmann::json; -using std::map; -using std::string; -using std::vector; - -namespace ge { -namespace { -constexpr char const *kKeyOp = "op"; -constexpr char const *kKeyInputDesc = "input_desc"; -constexpr char const *kKeyOutputDesc = "output_desc"; -constexpr char const *kKeyAttr = "attr"; -constexpr char const *kKeyName = "name"; -constexpr char const *kKeyType = "type"; -constexpr char const *kKeyShape = "shape"; -constexpr char const *kKeyShapeRange = "shape_range"; -constexpr char const *kKeyValue = "value"; -constexpr char const *kKeyFormat = "format"; -constexpr char const *kFileSuffix = ".om"; -constexpr int kDumpJsonIndent = 2; -constexpr int kShapeRangePairSize = 2; -constexpr int kShapeRangeLow = 0; -constexpr int kShapeRangeHigh = 1; - -map kAttrTypeDict = { - {"bool", GeAttrValue::VT_BOOL}, - {"int", GeAttrValue::VT_INT}, - {"float", GeAttrValue::VT_FLOAT}, - {"string", GeAttrValue::VT_STRING}, - {"list_bool", GeAttrValue::VT_LIST_BOOL}, - {"list_int", GeAttrValue::VT_LIST_INT}, - 
{"list_float", GeAttrValue::VT_LIST_FLOAT}, - {"list_string", GeAttrValue::VT_LIST_STRING}, - {"list_list_int", GeAttrValue::VT_LIST_LIST_INT}, -}; - -map kDataTypeDict = { - {"bool", DT_BOOL}, {"int8", DT_INT8}, {"uint8", DT_UINT8}, {"int16", DT_INT16}, {"uint16", DT_UINT16}, - {"int32", DT_INT32}, {"uint32", DT_UINT32}, {"int64", DT_INT64}, {"uint64", DT_UINT64}, {"float16", DT_FLOAT16}, - {"half", DT_FLOAT16}, {"fp16", DT_FLOAT16}, {"float", DT_FLOAT}, {"float32", DT_FLOAT}, {"double", DT_DOUBLE}, -}; - -map kFormatDict = { - {"nchw", FORMAT_NCHW}, {"nhwc", FORMAT_NHWC}, {"nd", FORMAT_ND}, {"fractal_nz", FORMAT_FRACTAL_NZ}, - {"fractal_z", FORMAT_FRACTAL_Z}, {"nc1hwc0", FORMAT_NC1HWC0}, -}; -} // namespace - -template -void SetAttrValue(const Json &j, SingleOpAttr &attr) { - attr.value.SetValue(j.at(kKeyValue).get()); -} - -template -T GetValue(const map &dict, string &key, T default_val) { - transform(key.begin(), key.end(), key.begin(), ::tolower); - auto it = dict.find(key); - if (it == dict.end()) { - return default_val; - } - - return it->second; -} - -void from_json(const Json &j, SingleOpTensorDesc &desc) { - desc.dims = j.at(kKeyShape).get>(); - auto it = j.find(kKeyShapeRange); - if (it != j.end()) { - desc.dim_ranges = j.at(kKeyShapeRange).get>>(); - } - string format_str = j.at(kKeyFormat).get(); - string type_str = j.at(kKeyType).get(); - desc.format = GetValue(kFormatDict, format_str, FORMAT_RESERVED); - desc.type = GetValue(kDataTypeDict, type_str, DT_UNDEFINED); - auto tensor_name = j.find(kKeyName); - if (tensor_name != j.end()) { - desc.name = tensor_name->get(); - } -} - -void from_json(const Json &j, SingleOpAttr &attr) { - attr.name = j.at(kKeyName).get(); - attr.type = j.at(kKeyType).get(); - auto it = kAttrTypeDict.find(attr.type); - if (it == kAttrTypeDict.end()) { - GELOGE(UNSUPPORTED, "Parse attr[%s] failed. 
Unsupported type: %s", attr.name.c_str(), attr.type.c_str()); - return; - } - - switch (it->second) { - case GeAttrValue::VT_BOOL: - SetAttrValue(j, attr); - break; - case GeAttrValue::VT_INT: - SetAttrValue(j, attr); - break; - case GeAttrValue::VT_FLOAT: - SetAttrValue(j, attr); - break; - case GeAttrValue::VT_STRING: - SetAttrValue(j, attr); - break; - case GeAttrValue::VT_LIST_BOOL: - SetAttrValue>(j, attr); - break; - case GeAttrValue::VT_LIST_INT: - SetAttrValue>(j, attr); - break; - case GeAttrValue::VT_LIST_FLOAT: - SetAttrValue>(j, attr); - break; - case GeAttrValue::VT_LIST_STRING: - SetAttrValue>(j, attr); - break; - case GeAttrValue::VT_LIST_LIST_INT: - SetAttrValue>>(j, attr); - break; - default: - GELOGE(UNSUPPORTED, "Parse attr[%s] failed. Unsupported type: %s", attr.name.c_str(), attr.type.c_str()); - break; - } -} - -void from_json(const Json &j, SingleOpDesc &desc) { - desc.op = j.at(kKeyOp).get(); - - auto input_desc = j.find(kKeyInputDesc); - if (input_desc != j.end()) { - desc.input_desc = input_desc->get>(); - } - - auto output_desc = j.find(kKeyOutputDesc); - if (output_desc != j.end()) { - desc.output_desc = output_desc->get>(); - } - - auto attr_field = j.find(kKeyAttr); - if (attr_field != j.end()) { - desc.attrs = attr_field->get>(); - } -} - -Status SingleOpParser::ReadJsonFile(const std::string &file, Json &json_obj) { - std::string real_path = RealPath(file.c_str()); - if (real_path.empty()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10023", {"value"}, {file}); - GELOGE(FAILED, "Input parameter[--singleop]'s value[%s] is not a valid path.", file.c_str()); - return INTERNAL_ERROR; - } - - std::ifstream ifs(real_path); - if (!ifs.is_open()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10024", {"value"}, {file}); - GELOGE(FAILED, "Open file[%s] provided in input parameter[--singleop] failed.", file.c_str()); - return FAILED; - } - try { - ifs >> json_obj; - } catch (const std::exception &e) { - 
ErrorManager::GetInstance().ATCReportErrMessage("E10025", {"realpath", "errmsg"}, {real_path, e.what()}); - GELOGE(PARAM_INVALID, "Parse file[%s] provided in input parameter[--singleop] failed, exception = %s.", - real_path.c_str(), e.what()); - return PARAM_INVALID; - } - - ifs.close(); - return SUCCESS; -} - -bool SingleOpParser::Validate(const SingleOpDesc &op_desc) { - if (op_desc.op.empty()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10026"); - GELOGE(PARAM_INVALID, "Op name is empty"); - return false; - } - - int index = 0; - for (auto &tensor_desc : op_desc.input_desc) { - if (tensor_desc.type == DT_UNDEFINED) { - ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "index"}, {"input", std::to_string(index)}); - GELOGE(false, "Input's dataType is invalid when the index is %d", index); - return false; - } - - if (tensor_desc.format == FORMAT_RESERVED) { - ErrorManager::GetInstance().ATCReportErrMessage("E10028", {"input", "index"}, {"input", std::to_string(index)}); - GELOGE(PARAM_INVALID, "Input's format is invalid when the index is %d", index); - return false; - } - ++index; - } - - index = 0; - for (auto &tensor_desc : op_desc.output_desc) { - if (tensor_desc.type == DT_UNDEFINED) { - ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "index"}, {"output", std::to_string(index)}); - GELOGE(PARAM_INVALID, "Output's dataType is invalid when the index is %d", index); - return false; - } - - if (tensor_desc.format == FORMAT_RESERVED) { - ErrorManager::GetInstance().ATCReportErrMessage("E10028", {"input", "index"}, {"output", std::to_string(index)}); - GELOGE(PARAM_INVALID, "Output's format is invalid when the index is %d", index); - return false; - } - ++index; - } - - for (auto &attr : op_desc.attrs) { - if (attr.name.empty()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10029"); - GELOGE(PARAM_INVALID, "attr name is empty"); - return false; - } - - if (attr.value.IsEmpty()) { - 
ErrorManager::GetInstance().ATCReportErrMessage("E10030", {"attrname"}, {attr.name}); - GELOGE(PARAM_INVALID, "Parse attr \"%s\" failed. ", attr.name.c_str()); - return false; - } - } - - return true; -} - -std::unique_ptr SingleOpParser::CreateOpDesc(const string &op_type) { - return std::unique_ptr(new (std::nothrow) OpDesc(op_type, op_type)); -} - -Status SingleOpParser::ConvertToBuildParam(int index, const SingleOpDesc &single_op_desc, - SingleOpBuildParam &build_param) { - auto op_desc = CreateOpDesc(single_op_desc.op); - if (op_desc == nullptr) { - GELOGE(MEMALLOC_FAILED, "Failed to create instance of opDesc"); - return MEMALLOC_FAILED; - } - - std::stringstream file_name; - file_name << index; - file_name << "_" << single_op_desc.op; - for (auto &desc : single_op_desc.input_desc) { - file_name << "_" << desc.type << "_" << desc.format; - for (auto dim : desc.dims) { - file_name << "_" << dim; - } - GeTensorDesc ge_tensor_desc(GeShape(desc.dims), desc.format, desc.type); - ge_tensor_desc.SetOriginFormat(desc.format); - GE_CHK_STATUS_RET_NOLOG(SetShapeRange(desc, ge_tensor_desc)); - TensorUtils::SetRealDimCnt(ge_tensor_desc, desc.dims.size()); - TensorUtils::SetInputTensor(ge_tensor_desc, true); - TensorUtils::SetOutputTensor(ge_tensor_desc, false); - if (desc.name.empty()) { - op_desc->AddInputDesc(ge_tensor_desc); - } else { - op_desc->AddInputDesc(desc.name, ge_tensor_desc); - } - build_param.inputs.emplace_back(ge_tensor_desc); - } - - for (auto &desc : single_op_desc.output_desc) { - file_name << "_" << desc.type << "_" << desc.format; - for (auto dim : desc.dims) { - file_name << "_" << dim; - } - - GeTensorDesc ge_tensor_desc(GeShape(desc.dims), desc.format, desc.type); - ge_tensor_desc.SetOriginFormat(desc.format); - GE_CHK_STATUS_RET_NOLOG(SetShapeRange(desc, ge_tensor_desc)); - TensorUtils::SetRealDimCnt(ge_tensor_desc, desc.dims.size()); - TensorUtils::SetInputTensor(ge_tensor_desc, false); - TensorUtils::SetOutputTensor(ge_tensor_desc, true); - 
op_desc->AddOutputDesc(ge_tensor_desc); - build_param.outputs.emplace_back(ge_tensor_desc); - } - - for (const auto &attr : single_op_desc.attrs) { - op_desc->SetAttr(attr.name, attr.value); - } - - if (VerifyOpInputOutputSizeByIr(*op_desc) != SUCCESS) { - GELOGE(PARAM_INVALID, "Verify op [%s] input or output size failed.", op_desc->GetType().c_str()); - return PARAM_INVALID; - } - - file_name << kFileSuffix; - build_param.file_name = file_name.str(); - build_param.op_desc.reset(op_desc.release()); - return SUCCESS; -} - -Status SingleOpParser::VerifyOpInputOutputSizeByIr(const OpDesc ¤t_op_desc) { - ge::Operator operator_ir = ge::OperatorFactory::CreateOperator("tmp_operator", current_op_desc.GetType()); - if (!operator_ir.IsEmpty()) { - auto opdesc_ir = ge::OpDescUtils::GetOpDescFromOperator(operator_ir); - GE_CHECK_NOTNULL(opdesc_ir); - size_t current_opdesc_inputs_num = current_op_desc.GetInputsSize(); - size_t ir_opdesc_inputs_num = opdesc_ir->GetInputsSize(); - if (current_opdesc_inputs_num < ir_opdesc_inputs_num) { - string reason = "is smaller than the ir needed input size " + std::to_string(ir_opdesc_inputs_num); - ErrorManager::GetInstance().ATCReportErrMessage( - "E19014", {"opname", "value", "reason"}, - {current_op_desc.GetName(), "input size " + std::to_string(current_opdesc_inputs_num), reason}); - GELOGE(PARAM_INVALID, "This op [%s] input size %zu is smaller than the ir needed input size %zu", - current_op_desc.GetName().c_str(), current_opdesc_inputs_num, ir_opdesc_inputs_num); - return PARAM_INVALID; - } - size_t current_opdesc_outputs_num = current_op_desc.GetOutputsSize(); - size_t ir_opdesc_outputs_num = opdesc_ir->GetOutputsSize(); - if (current_opdesc_outputs_num < ir_opdesc_outputs_num) { - string reason = "is smaller than the ir needed output size " + std::to_string(ir_opdesc_outputs_num); - ErrorManager::GetInstance().ATCReportErrMessage( - "E19014", {"opname", "value", "reason"}, - {current_op_desc.GetName(), "output size " + 
std::to_string(current_opdesc_outputs_num), reason}); - GELOGE(PARAM_INVALID, "This op [%s] output size %zu is smaller than the ir needed output size %zu", - current_op_desc.GetName().c_str(), current_opdesc_outputs_num, ir_opdesc_outputs_num); - return PARAM_INVALID; - } - } - return SUCCESS; -} - -Status SingleOpParser::SetShapeRange(const SingleOpTensorDesc &tensor_desc, GeTensorDesc &ge_tensor_desc) { - if (tensor_desc.dim_ranges.empty()) { - return SUCCESS; - } - - std::vector> shape_range; - size_t range_index = 0; - for (auto dim : tensor_desc.dims) { - if (dim >= 0) { - shape_range.emplace_back(dim, dim); - GELOGD("Adding shape range: [%ld, %ld]", dim, dim); - } else { - if (range_index >= tensor_desc.dim_ranges.size()) { - GELOGE(PARAM_INVALID, "The number of shape_range mismatches that of unknown dims."); - return PARAM_INVALID; - } - - auto &range = tensor_desc.dim_ranges[range_index]; - if (range.size() != kShapeRangePairSize) { - GELOGE(PARAM_INVALID, "Invalid shape range entry. 
index = %zu, size = %zu", range_index, range.size()); - return PARAM_INVALID; - } - - shape_range.emplace_back(range[kShapeRangeLow], range[kShapeRangeHigh]); - GELOGD("Adding shape range: [%ld, %ld]", range[kShapeRangeLow], range[kShapeRangeHigh]); - ++range_index; - } - } - - ge_tensor_desc.SetShapeRange(shape_range); - return SUCCESS; -} - -Status SingleOpParser::ParseSingleOpList(const std::string &file, std::vector &op_list) { - Json single_op_list_json; - auto ret = ReadJsonFile(file, single_op_list_json); - if (ret != SUCCESS) { - return ret; - } - - int index = 0; - for (const Json &single_op_json : single_op_list_json) { - GELOGI("Parsing op[%d], jsonStr = %s", index, single_op_json.dump(kDumpJsonIndent).c_str()); - SingleOpDesc single_op_desc; - try { - single_op_desc = single_op_json; - } catch (const nlohmann::json::exception &e) { - ErrorManager::GetInstance().ATCReportErrMessage("E10032", {"index", "jsonfile", "exception"}, - {std::to_string(index), file, e.what()}); - GELOGE(PARAM_INVALID, "Parse the index[%d] of op failed when read json file[%s], exception %s, jsonStr %s", index, - file.c_str(), e.what(), single_op_json.dump(kDumpJsonIndent).c_str()); - return PARAM_INVALID; - } - - if (!Validate(single_op_desc)) { - GELOGE(PARAM_INVALID, "Validate the index[%d] of op failed when read json file[%s].", index, file.c_str()); - return PARAM_INVALID; - } - - SingleOpBuildParam param; - ret = ConvertToBuildParam(index, single_op_desc, param); - if (ret != SUCCESS) { - return ret; - } - - op_list.emplace_back(param); - GELOGI("Parse the index[%d] of op success", index); - index += 1; - } - - return SUCCESS; -} -} // namespace ge diff --git a/src/ge/offline/single_op_parser.h b/src/ge/offline/single_op_parser.h deleted file mode 100644 index 472e1041..00000000 --- a/src/ge/offline/single_op_parser.h +++ /dev/null @@ -1,79 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the 
"License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef ACL_TOOLS_COMPILE_PARSER_H -#define ACL_TOOLS_COMPILE_PARSER_H - -#include -#include - -#include - -#include "ge/ge_api_error_codes.h" -#include "graph/types.h" -#include "graph/ge_attr_value.h" -#include "graph/op_desc.h" - -namespace ge { -struct SingleOpTensorDesc { - std::string name; - std::vector dims; - std::vector> dim_ranges; - ge::Format format = ge::FORMAT_RESERVED; - ge::DataType type = ge::DT_UNDEFINED; -}; - -struct SingleOpAttr { - std::string name; - std::string type; - ge::GeAttrValue value; -}; - -struct SingleOpDesc { - std::string op; - std::vector input_desc; - std::vector output_desc; - std::vector attrs; -}; - -struct SingleOpBuildParam { - ge::OpDescPtr op_desc; - std::vector inputs; - std::vector outputs; - std::string file_name; -}; - -void from_json(const nlohmann::json &json, SingleOpTensorDesc &desc); - -void from_json(const nlohmann::json &json, SingleOpAttr &desc); - -void from_json(const nlohmann::json &json, SingleOpDesc &desc); - -class SingleOpParser { - public: - static Status ParseSingleOpList(const std::string &file, std::vector &op_list); - - private: - static Status ReadJsonFile(const std::string &file, nlohmann::json &json_obj); - static bool Validate(const SingleOpDesc &op_desc); - static std::unique_ptr CreateOpDesc(const std::string &op_type); - static Status ConvertToBuildParam(int index, const SingleOpDesc &single_op_desc, SingleOpBuildParam &build_param); - static Status VerifyOpInputOutputSizeByIr(const 
OpDesc ¤t_op_desc); - static Status SetShapeRange(const SingleOpTensorDesc &tensor_desc, GeTensorDesc &ge_tensor_desc); -}; -} // namespace ge - -#endif // ACL_TOOLS_COMPILE_PARSER_H diff --git a/src/ge/opskernel_manager/ops_kernel_manager.cc b/src/ge/opskernel_manager/ops_kernel_manager.cc index 0d6f1e07..24c5a52d 100644 --- a/src/ge/opskernel_manager/ops_kernel_manager.cc +++ b/src/ge/opskernel_manager/ops_kernel_manager.cc @@ -73,7 +73,7 @@ Status OpsKernelManager::Initialize(const map &options_const) { options.emplace("ge.exec.isUseHvd", to_string(0)); } - GetExternalEnginePath(extern_engine_path); + GetExternalEnginePath(extern_engine_path, options); GELOGI("OPTION_EXEC_EXTERN_PLUGIN_PATH=%s.", extern_engine_path.c_str()); op_tiling_manager_.LoadSo(); @@ -123,7 +123,7 @@ Status OpsKernelManager::Initialize(const map &options_const) { } } -void OpsKernelManager::GetExternalEnginePath(std::string &extern_engine_path) { +void OpsKernelManager::GetExternalEnginePath(std::string &extern_engine_path, const std::map &options) { GELOGI("Enter get external engine so path schedule"); const char *path_env = std::getenv("ASCEND_ENGINE_PATH"); if (path_env != nullptr) { @@ -136,7 +136,11 @@ void OpsKernelManager::GetExternalEnginePath(std::string &extern_engine_path) { std::string path = path_base + so_path; extern_engine_path = (path + "libfe.so" + ":") + (path + "libge_local_engine.so" + ":") + (path + "librts_engine.so" + ":") + (path + "libaicpu_engine.so" + ":") + - (path_base + "libhcom_graph_adaptor.so"); + (path + "libhost_cpu_engine.so" + ":"); + auto iter = options.find(OPTION_EXEC_HCCL_FLAG); + if (iter == options.end() || iter->second != "0") { + extern_engine_path += (path_base + "libhcom_graph_adaptor.so"); + } } Status OpsKernelManager::InitPluginOptions(const map &options) { diff --git a/src/ge/opskernel_manager/ops_kernel_manager.h b/src/ge/opskernel_manager/ops_kernel_manager.h index 1d464201..43644d0e 100644 --- 
a/src/ge/opskernel_manager/ops_kernel_manager.h +++ b/src/ge/opskernel_manager/ops_kernel_manager.h @@ -91,7 +91,7 @@ class OpsKernelManager { Status CheckPluginPtr(); - void GetExternalEnginePath(std::string &path); + void GetExternalEnginePath(std::string &path, const std::map &options); void InitOpsKernelInfo(); diff --git a/src/ge/opskernel_manager/optimizer_priority.pbtxt b/src/ge/opskernel_manager/optimizer_priority.pbtxt old mode 100644 new mode 100755 index 76768817..9f8a03fb --- a/src/ge/opskernel_manager/optimizer_priority.pbtxt +++ b/src/ge/opskernel_manager/optimizer_priority.pbtxt @@ -1 +1 @@ -optimizer:["aicpu_original_optimizer","AIcoreEngine","VectorEngine","aicpu_optimizer","hccl_graph_optimizer", "hvd_graph_optimizer"] \ No newline at end of file +optimizer:["aicpu_original_optimizer","AIcoreEngine","VectorEngine","aicpu_optimizer","hccl_graph_optimizer", "hvd_graph_optimizer", "DNN_VM_RTS_GRAPH_OPTIMIZER_STORE"] \ No newline at end of file diff --git a/src/ge/plugin/engine/dnnengines.cc b/src/ge/plugin/engine/dnnengines.cc index e75fb74b..d85d1668 100644 --- a/src/ge/plugin/engine/dnnengines.cc +++ b/src/ge/plugin/engine/dnnengines.cc @@ -61,7 +61,7 @@ AICpuDNNEngine::AICpuDNNEngine(const std::string &engine_name) { engine_attribute_.engine_output_format = FORMAT_RESERVED; } -AICpuDNNEngine::AICpuDNNEngine(const DNNEngineAttribute &attrs) { engine_attribute_ = attrs; } +AICpuDNNEngine::AICpuDNNEngine(const DNNEngineAttribute &attrs) { engine_attribute_ = attrs; } Status AICpuDNNEngine::Initialize(const std::map &options) { return SUCCESS; } @@ -83,6 +83,22 @@ Status GeLocalDNNEngine::Finalize() { return SUCCESS; } void GeLocalDNNEngine::GetAttributes(DNNEngineAttribute &attrs) const { attrs = engine_attribute_; } +HostCpuDNNEngine::HostCpuDNNEngine(const std::string &engine_name) { + engine_attribute_.engine_name = engine_name; + engine_attribute_.compute_cost = COST_10; + engine_attribute_.runtime_type = HOST; + 
engine_attribute_.engine_input_format = FORMAT_RESERVED; + engine_attribute_.engine_output_format = FORMAT_RESERVED; +} + +HostCpuDNNEngine::HostCpuDNNEngine(const DNNEngineAttribute &attrs) { engine_attribute_ = attrs; } + +Status HostCpuDNNEngine::Initialize(const std::map &options) { return SUCCESS; } + +Status HostCpuDNNEngine::Finalize() { return SUCCESS; } + +void HostCpuDNNEngine::GetAttributes(DNNEngineAttribute &attrs) const { attrs = engine_attribute_; } + RtsDNNEngine::RtsDNNEngine(const std::string &engine_name) { engine_attribute_.engine_name = engine_name; engine_attribute_.engine_input_format = FORMAT_RESERVED; diff --git a/src/ge/plugin/engine/dnnengines.h b/src/ge/plugin/engine/dnnengines.h index 6f669cc9..d776c2b9 100644 --- a/src/ge/plugin/engine/dnnengines.h +++ b/src/ge/plugin/engine/dnnengines.h @@ -55,7 +55,6 @@ class VectorCoreDNNEngine : public DNNEngine { DNNEngineAttribute engine_attribute_; }; - class AICpuDNNEngine : public DNNEngine { public: AICpuDNNEngine() = default; @@ -86,6 +85,21 @@ class GeLocalDNNEngine : public DNNEngine { DNNEngineAttribute engine_attribute_; }; +class HostCpuDNNEngine : public DNNEngine { + public: + HostCpuDNNEngine() = default; + explicit HostCpuDNNEngine(const std::string &engine_name); + explicit HostCpuDNNEngine(const DNNEngineAttribute &attrs); + ~HostCpuDNNEngine() = default; + + Status Initialize(const std::map &options); + Status Finalize(); + void GetAttributes(DNNEngineAttribute &attr) const; + + private: + DNNEngineAttribute engine_attribute_; +}; + class RtsDNNEngine : public DNNEngine { public: RtsDNNEngine() = default; diff --git a/src/ge/plugin/engine/engine_manage.cc b/src/ge/plugin/engine/engine_manage.cc index d29c3ac7..82cd90ee 100644 --- a/src/ge/plugin/engine/engine_manage.cc +++ b/src/ge/plugin/engine/engine_manage.cc @@ -57,7 +57,7 @@ DNNEnginePtr EngineManager::GetEngine(const std::string &engine_name) { return engine; } -void GetDNNEngineObjs(std::map &engines) { +void 
RegisterAiCoreEngine() { const std::string ai_core = "AIcoreEngine"; std::vector mem_type_aicore; mem_type_aicore.emplace_back(GE_ENGINE_ATTR_MEM_TYPE_HBM); @@ -70,7 +70,9 @@ void GetDNNEngineObjs(std::map &engines) { if (EngineManager::RegisterEngine(ai_core, aicore_engine_ptr) != SUCCESS) { GELOGW("register ai_core failed"); } +} +void RegisterVectorEngine() { const std::string vector_core = "VectorEngine"; std::vector mem_type_aivcore; mem_type_aivcore.emplace_back(GE_ENGINE_ATTR_MEM_TYPE_HBM); @@ -81,11 +83,12 @@ void GetDNNEngineObjs(std::map &engines) { GELOGE(ge::FAILED, "make vectorCoreEnginePtr failed"); return; } - if (EngineManager::RegisterEngine(vector_core, vectorcore_engine_ptr) != SUCCESS) { GELOGW("register vector_core failed"); } +} +void RegisterAiCpuEngine() { const std::string vm_aicpu = "DNN_VM_AICPU"; std::vector mem_type_aicpu; mem_type_aicpu.emplace_back(GE_ENGINE_ATTR_MEM_TYPE_HBM); @@ -98,7 +101,9 @@ void GetDNNEngineObjs(std::map &engines) { if (EngineManager::RegisterEngine(vm_aicpu, vm_engine_ptr) != SUCCESS) { GELOGW("register vmAicpuEngine failed"); } +} +void RegisterGeLocalEngine() { const std::string vm_ge_local = "DNN_VM_GE_LOCAL"; std::vector mem_type_ge_local; mem_type_ge_local.emplace_back(GE_ENGINE_ATTR_MEM_TYPE_HBM); @@ -112,7 +117,25 @@ void GetDNNEngineObjs(std::map &engines) { if (EngineManager::RegisterEngine(vm_ge_local, ge_local_engine) != SUCCESS) { GELOGW("register ge_local_engine failed"); } +} +void RegisterHostCpuEngine() { + const std::string vm_host_cpu = "DNN_VM_HOST_CPU"; + std::vector mem_type_host_cpu; + mem_type_host_cpu.emplace_back(GE_ENGINE_ATTR_MEM_TYPE_HBM); + // HostCpu use minimum priority, set it as 10 + DNNEngineAttribute attr_host_cpu = {vm_host_cpu, mem_type_host_cpu, COST_10, HOST, FORMAT_RESERVED, FORMAT_RESERVED}; + DNNEnginePtr host_cpu_engine = MakeShared(attr_host_cpu); + if (host_cpu_engine == nullptr) { + GELOGE(ge::FAILED, "make host_cpu_engine failed"); + return; + } + if 
(EngineManager::RegisterEngine(vm_host_cpu, host_cpu_engine) != SUCCESS) { + GELOGW("register host_cpu_engine failed"); + } +} + +void RegisterRtsEngine() { const std::string vm_rts = "DNN_VM_RTS"; std::vector mem_type_rts; mem_type_rts.emplace_back(GE_ENGINE_ATTR_MEM_TYPE_HBM); @@ -125,7 +148,9 @@ void GetDNNEngineObjs(std::map &engines) { if (EngineManager::RegisterEngine(vm_rts, rts_engine) != SUCCESS) { GELOGW("register rts_engine failed"); } +} +void RegisterHcclEngine() { const std::string dnn_hccl = "DNN_HCCL"; std::vector mem_type_hccl; mem_type_hccl.emplace_back(GE_ENGINE_ATTR_MEM_TYPE_HBM); @@ -138,6 +163,16 @@ void GetDNNEngineObjs(std::map &engines) { if (EngineManager::RegisterEngine(dnn_hccl, hccl_engine) != SUCCESS) { GELOGW("register hccl_engine failed"); } +} + +void GetDNNEngineObjs(std::map &engines) { + RegisterAiCoreEngine(); + RegisterVectorEngine(); + RegisterAiCpuEngine(); + RegisterGeLocalEngine(); + RegisterHostCpuEngine(); + RegisterRtsEngine(); + RegisterHcclEngine(); for (auto it = EngineManager::engine_map_->begin(); it != EngineManager::engine_map_->end(); ++it) { GELOGI("get engine %s from engine plugin.", it->first.c_str()); diff --git a/src/ge/session/inner_session.cc b/src/ge/session/inner_session.cc index b97862e1..a4e77b73 100644 --- a/src/ge/session/inner_session.cc +++ b/src/ge/session/inner_session.cc @@ -295,4 +295,17 @@ bool InnerSession::IsGraphNeedRebuild(uint32_t graph_id) { UpdateThreadContext(graph_id); return graph_manager_.IsGraphNeedRebuild(graph_id); } + +Status InnerSession::GetAllVariables(std::map &all_variables) { + return VarManager::Instance(session_id_)->GetAllVariables(all_variables); +} + +Status InnerSession::GenCheckPointGraph(const std::map &all_variables, Graph &graph) { + return graph_manager_.GenCheckPointGraph(all_variables, graph); +} + +Status InnerSession::SaveVariables(const Graph &graph, const std::vector &var_names, + const std::vector &outputs, std::vector &var_values) { + return 
graph_manager_.SaveVariables(graph, var_names, outputs, var_values); +} } // namespace ge diff --git a/src/ge/session/inner_session.h b/src/ge/session/inner_session.h index bcc47354..3d9bf39f 100644 --- a/src/ge/session/inner_session.h +++ b/src/ge/session/inner_session.h @@ -47,6 +47,13 @@ class InnerSession { Status Finalize(); + Status GetAllVariables(std::map &all_variables); + + Status GenCheckPointGraph(const std::map &all_variables, Graph &graph); + + Status SaveVariables(const Graph &graph, const std::vector &var_names, + const std::vector &outputs, std::vector &var_values); + Status GetVariable(const std::string &name, Tensor &val); Status RegisterCallBackFunc( diff --git a/src/ge/session/omg.cc b/src/ge/session/omg.cc index 55075d6a..805f8653 100644 --- a/src/ge/session/omg.cc +++ b/src/ge/session/omg.cc @@ -25,7 +25,6 @@ #include "common/ge/ge_util.h" #include "common/helper/model_helper.h" #include "common/model_parser/base.h" -#include "common/model_parser/graph_parser_util.h" #include "common/model_saver.h" #include "common/properties_manager.h" #include "common/string_util.h" @@ -116,13 +115,9 @@ static Status CheckInputShapeNode(const ComputeGraphPtr &graph) { void AddAttrsForInputNodes(const vector &adjust_fp16_format_vec, const string &fp16_nodes_name, uint32_t index, OpDescPtr &op_desc) { - if (AttrUtils::SetBool(op_desc, "input_fp16", true) && - AttrUtils::SetStr(op_desc, ATTR_ATC_USER_DEFINE_DATATYPE, TypeUtils::DataTypeToSerialString(DT_FLOAT16))) { + if (AttrUtils::SetStr(op_desc, ATTR_ATC_USER_DEFINE_DATATYPE, TypeUtils::DataTypeToSerialString(DT_FLOAT16))) { if ((index < adjust_fp16_format_vec.size()) && (adjust_fp16_format_vec[index] == "true")) { GELOGI("This node [%s] should be set NC1HWC0", fp16_nodes_name.c_str()); - if (!AttrUtils::SetBool(op_desc, "input_set_nc1hwc0", true)) { - GELOGW("This node [%s] set NC1HWC0 failed", fp16_nodes_name.c_str()); - } if (!AttrUtils::SetStr(op_desc, ATTR_ATC_USER_DEFINE_FORMAT, 
TypeUtils::FormatToSerialString(FORMAT_NC1HWC0))) { GELOGW("This node [%s] set NC1HWC0 failed", fp16_nodes_name.c_str()); } @@ -211,6 +206,30 @@ static Status SetWeightCompressNodes(const ComputeGraphPtr &graph, const string return SUCCESS; } +static Status ParseOutputFp16NodesFormat(const string &is_output_fp16) { + if (is_output_fp16.empty()) { + return SUCCESS; + } + + vector &output_formats = domi::GetContext().output_formats; + output_formats.clear(); + vector node_format_vec = StringUtils::Split(is_output_fp16, ','); + for (auto &is_fp16 : node_format_vec) { + StringUtils::Trim(is_fp16); + if (!CheckInputTrueOrFalse(is_fp16, "is_output_adjust_hw_layout")) { + GELOGE(PARAM_INVALID, "Invalid Param, is_output_adjust_hw_layout only support true/false: but is [%s]", + is_output_fp16.c_str()); + return PARAM_INVALID; + } + if (is_fp16 == "false") { + output_formats.push_back(DOMI_TENSOR_ND); + } else if (is_fp16 == "true") { + output_formats.push_back(domi::DOMI_TENSOR_NC1HWC0); + } + } + return SUCCESS; +} + void FindParserSo(const string &path, vector &file_list, string &caffe_parser_path) { // path, Change to absolute path string real_path = RealPath(path.c_str()); @@ -292,6 +311,192 @@ Status SetOutFormatAndDataTypeAttr(ge::OpDescPtr op_desc, const ge::Format forma return domi::SUCCESS; } +bool CheckDigitStr(std::string &str) { + for (char c : str) { + if (!isdigit(c)) { + GELOGE(domi::FAILED, "value[%s] is not positive integer", str.c_str()); + return false; + } + } + return true; +} + +Status StringToInt(std::string &str, int32_t &value) { + try { + if (!CheckDigitStr(str)) { + GELOGE(PARAM_INVALID, "Invalid of digit string: %s ", str.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--output_type", str, "is not positive integer"}); + return PARAM_INVALID; + } + value = stoi(str); + } catch (std::invalid_argument &) { + GELOGE(PARAM_INVALID, "Invalid of digit string: %s, catch invalid_argument.", 
str.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, {"--output_type", str}); + return PARAM_INVALID; + } catch (std::out_of_range &) { + GELOGE(PARAM_INVALID, "Invalid of digit string: %s, catch out_of_range.", str.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, {"--output_type", str}); + return PARAM_INVALID; + } + return SUCCESS; +} + +Status VerifyOutputTypeAndOutNodes(std::vector &out_type_vec) { + std::vector> user_out_nodes = domi::GetContext().user_out_nodes; + std::set out_nodes_info; + for (uint32_t i = 0; i < user_out_nodes.size(); ++i) { + // out_nodes set should include output_type and output_format + std::string tmp = user_out_nodes[i].first + ":" + to_string(user_out_nodes[i].second); + out_nodes_info.emplace(tmp); + } + for (uint32_t i = 0; i < out_type_vec.size(); ++i) { + if (out_nodes_info.find(out_type_vec[i]) == out_nodes_info.end()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--output_type", out_type_vec[i], kOutputTypeError}); + GELOGE(domi::FAILED, "Invalid value for --output_type[%s], %s.", out_type_vec[i].c_str(), kOutputTypeError); + return domi::FAILED; + } + } + return domi::SUCCESS; +} + +Status CheckOutPutDataTypeSupport(const std::string &output_type) { + auto it = output_type_str_to_datatype.find(output_type); + if (it == output_type_str_to_datatype.end()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--output_type", output_type, kOutputTypeSupport}); + GELOGE(PARAM_INVALID, "Invalid value for --output_type[%s], %s.", output_type.c_str(), kOutputTypeSupport); + return domi::FAILED; + } + return domi::SUCCESS; +} + +Status ParseOutputType(const std::string &output_type, std::map> &output_node_dt_map) { + if (output_type.find(':') == std::string::npos) { + GELOGI("output_type is not multiple nodes, means all out nodes"); + return 
CheckOutPutDataTypeSupport(output_type); + } + std::vector out_type_vec; + vector nodes_v = StringUtils::Split(output_type, ';'); + for (const string &node : nodes_v) { + vector node_index_type_v = StringUtils::Split(node, ':'); + if (node_index_type_v.size() != 3) { // The size must be 3. + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--output_type", node, kOutputTypeSample}); + GELOGE(PARAM_INVALID, "Invalid value for --output_type[%s], %s.", node.c_str(), kOutputTypeSample); + return domi::FAILED; + } + ge::DataType tmp_dt; + std::string node_name = StringUtils::Trim(node_index_type_v[0]); + std::string index_str = StringUtils::Trim(node_index_type_v[1]); + int32_t index; + if (StringToInt(index_str, index) != SUCCESS) { + GELOGE(PARAM_INVALID, "This str must be digit string, while the actual input is %s.", index_str.c_str()); + return domi::FAILED; + } + std::string dt_value = StringUtils::Trim(node_index_type_v[2]); + auto it = output_type_str_to_datatype.find(dt_value); + if (it == output_type_str_to_datatype.end()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--output_type", dt_value, kOutputTypeSupport}); + GELOGE(ge::PARAM_INVALID, "Invalid value for --output_type[%s], %s.", dt_value.c_str(), kOutputTypeSupport); + return domi::FAILED; + } else { + tmp_dt = it->second; + } + out_type_vec.push_back(node_name + ":" + index_str); + std::string index_dt_str = index_str + ":" + TypeUtils::DataTypeToSerialString(tmp_dt); + auto it1 = output_node_dt_map.find(node_name); + if (it1 == output_node_dt_map.end()) { + vector tmp_vec; + tmp_vec.push_back(index_dt_str); + output_node_dt_map.emplace(node_name, tmp_vec); + } else { + it1->second.push_back(index_dt_str); + } + } + return VerifyOutputTypeAndOutNodes(out_type_vec); +} + +Status CheckOutNode(ge::OpDescPtr op_desc, int32_t index) { + int32_t out_size = op_desc->GetOutputsSize(); + if (index < 0 || index >= 
out_size) { + GELOGE(domi::FAILED, + "out_node [%s] output index:%d must be smaller " + "than node output size:%d and can not be negative!", + op_desc->GetName().c_str(), index, out_size); + std::string fail_reason = "output index:" + to_string(index) + + " must be smaller than output size:" + to_string(out_size) + " and can not be negative!"; + ErrorManager::GetInstance().ATCReportErrMessage("E10003", {"parameter", "value", "reason"}, + {"out_nodes", op_desc->GetName(), fail_reason}); + return domi::FAILED; + } + return domi::SUCCESS; +} + +Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const std::string &output) { + ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); + GE_CHECK_NOTNULL(compute_graph); + + std::vector> user_out_nodes = domi::GetContext().user_out_nodes; + std::vector output_formats = domi::GetContext().output_formats; + std::vector> output_nodes_info; + std::vector output_nodes_name; + std::map> output_node_dt_map; + if (!output_type.empty()) { + if (ParseOutputType(output_type, output_node_dt_map) != SUCCESS) { + GELOGE(domi::FAILED, "Parse output_type failed."); + return domi::FAILED; + } + } + + // User declared outputs + for (uint32_t i = 0; i < user_out_nodes.size(); ++i) { + ge::NodePtr out_node = compute_graph->FindNode(user_out_nodes[i].first); + if (out_node == nullptr) { + ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, + {"out_nodes", user_out_nodes[i].first}); + GELOGE(domi::FAILED, "Can not find src node (%s) in graph.", user_out_nodes[i].first.c_str()); + return domi::FAILED; + } + auto op_desc = out_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + if (CheckOutNode(op_desc, user_out_nodes[i].second) != SUCCESS) { + GELOGE(domi::FAILED, "Check out node (%s) fail.", user_out_nodes[i].first.c_str()); + return domi::FAILED; + } + if (i < output_formats.size()) { + if (output_formats[i] == domi::DOMI_TENSOR_NC1HWC0) { + GELOGI("The output node [%s] 
should be set NC1HWC0", user_out_nodes[i].first.c_str()); + vector output_fp16_5hd_vec; + (void)ge::AttrUtils::GetListStr(op_desc, "_user_defined_output_fp16_5hd", output_fp16_5hd_vec); + output_fp16_5hd_vec.push_back(std::to_string(user_out_nodes[i].second) + ":" + "NC1HWC0"); + (void)ge::AttrUtils::SetListStr(op_desc, "_user_defined_output_fp16_5hd", output_fp16_5hd_vec); + } + } + auto it = output_node_dt_map.find(user_out_nodes[i].first); + if (it != output_node_dt_map.end()) { + GELOGI("The output node [%s] need to be set output_type", user_out_nodes[i].first.c_str()); + (void)ge::AttrUtils::SetListStr(op_desc, "_user_defined_output_data_type", it->second); + } + output_nodes_info.push_back(std::make_pair(out_node, user_out_nodes[i].second)); + } + // default output node (leaf) + if (user_out_nodes.empty()) { + for (ge::NodePtr node : compute_graph->GetDirectNode()) { + if (!node->GetInAllNodes().empty() && node->GetOutAllNodes().empty()) { + Status ret = GetOutputLeaf(node, output_nodes_info); + GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "find leaf fail."); + } + } + } + GetOutputNodesNameAndIndex(output_nodes_info, output_nodes_name); + compute_graph->SetGraphOutNodesInfo(output_nodes_info); + domi::GetContext().net_out_nodes = output_nodes_name; + return domi::SUCCESS; +} + void GetOutputNodesNameAndIndex(std::vector> &output_nodes_info, std::vector &output_nodes_name) { output_nodes_name.clear(); @@ -317,6 +522,32 @@ void GetOutputNodesNameAndIndex(std::vector> &ou } } +Status GetOutputLeaf(NodePtr node, std::vector> &output_nodes_info) { + ge::OpDescPtr tmpDescPtr = node->GetOpDesc(); + if (tmpDescPtr == nullptr) { + GELOGE(domi::FAILED, "Get outnode op desc fail."); + return domi::FAILED; + } + size_t size = tmpDescPtr->GetOutputsSize(); + if (node->GetType() != NETOUTPUT) { + for (size_t index = 0; index < size; ++index) { + output_nodes_info.push_back(std::make_pair(node, index)); + } + } else { + const auto in_anchors = node->GetAllInDataAnchors(); + 
for (auto in_anchor : in_anchors) { + auto out_anchor = in_anchor->GetPeerOutAnchor(); + if (out_anchor == nullptr) { + GELOGE(domi::FAILED, "Get leaf node op desc fail."); + return domi::FAILED; + } + auto out_node = out_anchor->GetOwnerNode(); + output_nodes_info.push_back(std::make_pair(out_node, out_anchor->GetIdx())); + } + } + return SUCCESS; +} + /// /// @ingroup domi_common /// @brief Initialize omgcontext based on command line input @@ -360,6 +591,57 @@ Status InitDomiOmgContext(const string &input_shape, const string &input_format, return SUCCESS; } +Status ParseOutNodes(const string &out_nodes) { + try { + // parse output node + if (!out_nodes.empty()) { + domi::GetContext().out_nodes_map.clear(); + domi::GetContext().user_out_nodes.clear(); + + vector nodes_v = StringUtils::Split(out_nodes, ';'); + for (const string &node : nodes_v) { + vector key_value_v = StringUtils::Split(node, ':'); + if (key_value_v.size() != 2) { // The size must be 2. + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, + {"--out_nodes", node, "the correct format is \"node_name1:0;node_name1:1;node_name2:0\""}); + GELOGE(PARAM_INVALID, + "The input format of --out_nodes is invalid, the correct format is " + "\"node_name1:0;node_name1:1;node_name2:0\", while the actual input is %s.", + node.c_str()); + return PARAM_INVALID; + } + auto iter = domi::GetContext().out_nodes_map.find(key_value_v[0]); + // stoi: The method may throw an exception: invalid_argument/out_of_range + if (!CheckDigitStr(key_value_v[1])) { + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--out_nodes", out_nodes, "is not positive integer"}); + GELOGE(PARAM_INVALID, "This str must be digit string, while the actual input is %s", out_nodes.c_str()); + return PARAM_INVALID; + } + int32_t index = stoi(StringUtils::Trim(key_value_v[1])); + if (iter != domi::GetContext().out_nodes_map.end()) { + 
iter->second.emplace_back(index); + } else { + std::vector index_v; + index_v.emplace_back(index); + domi::GetContext().out_nodes_map.emplace(key_value_v[0], index_v); + } + domi::GetContext().user_out_nodes.push_back(std::make_pair(key_value_v[0], index)); + } + } + } catch (std::invalid_argument &) { + GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", out_nodes.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, {"--out_nodes", out_nodes}); + return PARAM_INVALID; + } catch (std::out_of_range &) { + GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", out_nodes.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, {"--out_nodes", out_nodes}); + return PARAM_INVALID; + } + return SUCCESS; +} + /// @ingroup domi_common /// @brief Judge whether the op_Name_Map parameter matches the network /// @param [in] graph Input network graph @@ -377,7 +659,12 @@ static Status CheckOpNameMap(const ComputeGraphPtr &graph, const std::string &op graphNodeTypes[op_desc->GetType()] = ""; } std::map &propertiesMap = domi::GetContext().op_conf_map; - GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(propertiesMap.empty(), "op_name_map file is empty, please check file!"); + if (propertiesMap.empty()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10003", {"parameter", "value", "reason"}, + {"op_name_map", op_conf, "the file content is empty"}); + GELOGE(PARAM_INVALID, "op_name_map file content is empty, please check file!"); + return PARAM_INVALID; + } for (auto iter = propertiesMap.begin(); iter != propertiesMap.end(); iter++) { GE_IF_BOOL_EXEC(graphNodeTypes.find(iter->second) == graphNodeTypes.end(), ErrorManager::GetInstance().ATCReportErrMessage( @@ -428,7 +715,7 @@ FMK_FUNC_HOST_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::mapRunGraphAsync(graph_id, inputs, callback); } + +Status SessionManager::GetVariables(SessionId session_id, const std::vector &var_names, + std::vector &var_values) { + // 
step 0: init session manager + if (!init_flag_) { + GELOGE(GE_SESSION_MANAGER_NOT_INIT); + return GE_SESSION_MANAGER_NOT_INIT; + } + SessionPtr innerSession = nullptr; + { + std::lock_guard lock(mutex_); + std::map::iterator it = session_manager_map_.find(session_id); + if (it == session_manager_map_.end()) { + return GE_SESSION_NOT_EXIST; + } else { + innerSession = it->second; + } + } + + // step 1: get all variable + std::map all_variables; + Status ret = innerSession->GetAllVariables(all_variables); + if (ret != SUCCESS) { + GELOGE(FAILED, "Get all variables failed."); + return FAILED; + } + + // srep 2: create check point graph + Graph graph = Graph("checkpoint"); + ret = innerSession->GenCheckPointGraph(all_variables, graph); + if (ret != SUCCESS) { + GELOGE(FAILED, "Build check point graph failed."); + return FAILED; + } + + // step 3: run check point graph + uint32_t graph_id = GetCurrentSecondTimestap(); + ret = AddGraph(session_id, graph_id, graph); + if (ret != SUCCESS) { + GELOGE(FAILED, "Add check point graph failed."); + return FAILED; + } + + vector inputs; + vector outputs; + ret = RunGraph(session_id, graph_id, inputs, outputs); + if (ret != SUCCESS) { + GELOGE(FAILED, "Run check point graph failed."); + return FAILED; + } + + // step 4: save variables + ret = innerSession->SaveVariables(graph, var_names, outputs, var_values); + GELOGD("[SessionManager] outputs size is [%zu], var values size is [%zu].", outputs.size(), var_values.size()); + + if (ret != SUCCESS) { + GELOGE(FAILED, "Save variables failed."); + return FAILED; + } + return ret; +} + bool SessionManager::IsGraphNeedRebuild(SessionId session_id, uint32_t graph_id) { if (!init_flag_) { GELOGE(GE_SESSION_MANAGER_NOT_INIT); diff --git a/src/ge/session/session_manager.h b/src/ge/session/session_manager.h index 5cdb849f..1efb47d8 100644 --- a/src/ge/session/session_manager.h +++ b/src/ge/session/session_manager.h @@ -124,6 +124,16 @@ class SessionManager { Status RunGraphAsync(SessionId 
session_id, uint32_t graph_id, const std::vector &inputs, RunAsyncCallback callback); + /// + /// @ingroup ge_graph + /// @brief get variables in the session with specific session id + /// @param [in] session_id: sssion id + /// @param [in] var_names: variable names + /// @param [out] var_values: variable values + /// @return Status result of function + /// + Status GetVariables(SessionId session_id, const std::vector &var_names, std::vector &var_values); + /// /// @ingroup ge_graph /// @brief me register the callback function to get the result of summary or checkpoin diff --git a/src/ge/single_op/single_op.cc b/src/ge/single_op/single_op.cc index 1a63c964..5fa4efcf 100644 --- a/src/ge/single_op/single_op.cc +++ b/src/ge/single_op/single_op.cc @@ -34,6 +34,9 @@ size_t GetAlignedSize(uint32_t size) { return aligned_size; } } // namespace + +SingleOp::SingleOp(std::mutex *stream_mutex, rtStream_t stream) : stream_mutex_(stream_mutex), stream_(stream) {} + FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY SingleOp::~SingleOp() { for (auto task : tasks_) { delete task; @@ -52,7 +55,7 @@ Status SingleOp::ValidateArgs(const std::vector &inputs, const std:: for (size_t i = 0; i < num_inputs; ++i) { // preventing from read out of bound size_t aligned_size = GetAlignedSize(inputs[i].length); - GELOGI("Input [%zu], aligned_size:%zu, inputs.length:%u, input_sizes_:%u", i, aligned_size, inputs[i].length, + GELOGI("Input [%zu], aligned_size:%zu, inputs.length:%lu, input_sizes_:%lu", i, aligned_size, inputs[i].length, input_sizes_[i]); if (aligned_size < input_sizes_[i]) { GELOGE(PARAM_INVALID, @@ -72,7 +75,7 @@ Status SingleOp::ValidateArgs(const std::vector &inputs, const std:: for (size_t i = 0; i < num_outputs; ++i) { // preventing from write out of bound size_t aligned_size = GetAlignedSize(outputs[i].length); - GELOGI("Output [%zu], aligned_size:%zu, outputs.length:%u, output_sizes_:%u", i, aligned_size, outputs[i].length, + GELOGI("Output [%zu], aligned_size:%zu, 
outputs.length:%lu, output_sizes_:%lu", i, aligned_size, outputs[i].length, output_sizes_[i]); if (aligned_size < output_sizes_[i]) { GELOGE(PARAM_INVALID, @@ -133,9 +136,7 @@ Status SingleOp::UpdateArgs(const std::vector &inputs, const std::ve size_t io_addr_num = args_.size(); if (task->GetOpTaskType() == OP_TASK_AICPU) { GELOGD("Update aicpu_TF task args"); - AiCpuTask *task_aicpu = dynamic_cast(task); - GE_CHECK_NOTNULL(task_aicpu); - auto *dst_io_addr = const_cast(reinterpret_cast(task_aicpu->GetIOAddr())); + auto *dst_io_addr = const_cast(reinterpret_cast(task->GetIOAddr())); GE_CHECK_NOTNULL(dst_io_addr); auto rt_ret = rtMemcpyAsync(dst_io_addr, sizeof(uint64_t) * args_.size(), &args_[0], sizeof(uint64_t) * args_.size(), RT_MEMCPY_HOST_TO_DEVICE_EX, stream_); @@ -145,9 +146,7 @@ Status SingleOp::UpdateArgs(const std::vector &inputs, const std::ve } } else if (task->GetOpTaskType() == OP_TASK_AICPUCC) { GELOGD("Update aicpu_CC task args"); - AiCpuCCTask *task_aicpu_cc = dynamic_cast(task); - GE_CHECK_NOTNULL(task_aicpu_cc); - const uintptr_t *task_io_addr = reinterpret_cast(task_aicpu_cc->GetIOAddr()); + const uintptr_t *task_io_addr = reinterpret_cast(task->GetIOAddr()); GE_CHECK_NOTNULL(task_io_addr); auto io_addr = reinterpret_cast(const_cast(task_io_addr)); for (size_t i = 0; i < io_addr_num; ++i) { @@ -168,6 +167,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status SingleOp::ExecuteAsync(c return ret; } + std::lock_guard lk(*stream_mutex_); ret = UpdateArgs(inputs, outputs); if (ret != SUCCESS) { return ret; @@ -185,8 +185,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status SingleOp::ExecuteAsync(c void SingleOp::SetStream(rtStream_t stream) { stream_ = stream; } -DynamicSingleOp::DynamicSingleOp(uintptr_t resource_id, rtStream_t stream) - : resource_id_(resource_id), stream_(stream) {} +DynamicSingleOp::DynamicSingleOp(uintptr_t resource_id, std::mutex *stream_mutex, rtStream_t stream) + : resource_id_(resource_id), 
stream_mutex_(stream_mutex), stream_(stream) {} Status DynamicSingleOp::ValidateParams(const vector &input_desc, const std::vector &inputs, std::vector &output_desc, std::vector &outputs) const { @@ -252,6 +252,7 @@ Status DynamicSingleOp::ExecuteAsync(const vector &input_desc, con vector &output_desc, vector &output_buffers) { GE_CHECK_NOTNULL(op_task_); GE_CHK_STATUS_RET_NOLOG(ValidateParams(input_desc, input_buffers, output_desc, output_buffers)); + std::lock_guard lk(*stream_mutex_); GE_CHK_STATUS_RET_NOLOG(op_task_->UpdateRunInfo(input_desc, output_desc)); std::vector workspace_buffers; GE_CHK_STATUS_RET_NOLOG(AllocateWorkspaces(op_task_->GetWorkspaceSizes(), workspace_buffers)); diff --git a/src/ge/single_op/single_op.h b/src/ge/single_op/single_op.h index d86c79ee..71096f35 100644 --- a/src/ge/single_op/single_op.h +++ b/src/ge/single_op/single_op.h @@ -19,6 +19,7 @@ #include #include +#include #include #include @@ -30,7 +31,7 @@ namespace ge { class SingleOp { public: - SingleOp() = default; + SingleOp(std::mutex *stream_mutex, rtStream_t stream); ~SingleOp(); Status ExecuteAsync(const std::vector &inputs, const std::vector &outputs); @@ -42,6 +43,7 @@ class SingleOp { Status GetArgs(const std::vector &inputs, const std::vector &outputs); friend class SingleOpModel; + std::mutex *stream_mutex_; rtStream_t stream_ = nullptr; std::vector input_addr_list_; std::vector input_sizes_; @@ -56,7 +58,7 @@ class SingleOp { class DynamicSingleOp { public: - DynamicSingleOp(uintptr_t resource_id, rtStream_t stream); + DynamicSingleOp(uintptr_t resource_id, std::mutex *stream_mutex_, rtStream_t stream); ~DynamicSingleOp() = default; Status ExecuteAsync(const vector &input_desc, const std::vector &inputs, std::vector &output_desc, std::vector &outputs); @@ -70,6 +72,7 @@ class DynamicSingleOp { std::unique_ptr op_task_; uintptr_t resource_id_ = 0; + std::mutex *stream_mutex_; rtStream_t stream_ = nullptr; size_t num_inputs_ = 0; size_t num_outputs_ = 0; diff --git 
a/src/ge/single_op/single_op_manager.cc b/src/ge/single_op/single_op_manager.cc index aa6f6d2b..709b238f 100644 --- a/src/ge/single_op/single_op_manager.cc +++ b/src/ge/single_op/single_op_manager.cc @@ -52,27 +52,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status SingleOpManager::GetOpFr return SUCCESS; } - SingleOpModel model(model_name, model_data.model_data, model_data.model_len); - auto ret = model.Init(); - if (ret != SUCCESS) { - GELOGE(ret, "Init model failed. model = %s, ret = %u", model_name.c_str(), ret); - return ret; - } - - auto new_op = std::unique_ptr(new (std::nothrow) SingleOp()); - if (new_op == nullptr) { - GELOGE(MEMALLOC_FAILED, "new SingleOp failed"); - return MEMALLOC_FAILED; - } - - GELOGI("To build operator: %s", model_name.c_str()); - GE_CHK_STATUS_RET(model.BuildOp(*res, *new_op), "Build op failed. op = %s, ret = %u", model_name.c_str(), ret); - - // stream is nullable - new_op->SetStream(stream); - *single_op = new_op.get(); - res->CacheOperator(model_data.model_data, std::move(new_op)); - return SUCCESS; + return res->BuildOperator(model_name, model_data, single_op); } FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status SingleOpManager::ReleaseResource(void *stream) { @@ -94,7 +74,7 @@ StreamResource *SingleOpManager::GetResource(uintptr_t resource_id, rtStream_t s auto it = stream_resources_.find(resource_id); StreamResource *res = nullptr; if (it == stream_resources_.end()) { - res = new (std::nothrow) StreamResource(); + res = new (std::nothrow) StreamResource(resource_id); if (res != nullptr) { res->SetStream(stream); stream_resources_.emplace(resource_id, res); @@ -118,6 +98,10 @@ StreamResource *SingleOpManager::TryGetResource(uintptr_t resource_id) { Status SingleOpManager::GetDynamicOpFromModel(const string &model_name, const ModelData &model_data, void *stream, DynamicSingleOp **single_op) { + if (!tiling_func_registered_) { + RegisterTilingFunc(); + } + GE_CHECK_NOTNULL(single_op); uintptr_t resource_id = 0; 
GE_CHK_STATUS_RET(GetResourceId(stream, resource_id)); @@ -134,25 +118,7 @@ Status SingleOpManager::GetDynamicOpFromModel(const string &model_name, const Mo return SUCCESS; } - if (!tiling_func_registered_) { - RegisterTilingFunc(); - } - - SingleOpModel model(model_name, model_data.model_data, model_data.model_len); - auto ret = model.Init(); - if (ret != SUCCESS) { - GELOGE(ret, "Init model failed. model = %s, ret = %u", model_name.c_str(), ret); - return ret; - } - - auto new_op = std::unique_ptr(new (std::nothrow) DynamicSingleOp(resource_id, stream)); - GE_CHECK_NOTNULL(new_op); - - GELOGI("To build operator: %s", model_name.c_str()); - GE_CHK_STATUS_RET(model.BuildDynamicOp(*new_op), "Build op failed. op = %s, ret = %u", model_name.c_str(), ret); - *single_op = new_op.get(); - res->CacheDynamicOperator(model_data.model_data, std::move(new_op)); - return SUCCESS; + return res->BuildDynamicOperator(model_name, model_data, single_op); } void SingleOpManager::RegisterTilingFunc() { diff --git a/src/ge/single_op/single_op_model.cc b/src/ge/single_op/single_op_model.cc index 27958e7c..65f76acc 100644 --- a/src/ge/single_op/single_op_model.cc +++ b/src/ge/single_op/single_op_model.cc @@ -100,13 +100,18 @@ Status SingleOpModel::InitModelMem(StreamResource &res) { } } - if (model_params_.weight_size > 0) { + if (model_params_.weight_size > 0 && has_weight_) { const string purpose("malloc weights memory on model execute."); model_params_.weight_base = res.MallocWeight(purpose, model_params_.weight_size); if (model_params_.weight_base == nullptr) { // no need to free memory, for that was handled by StreamResources return RT_FAILED; } + + auto weight_buffer = model_helper_.GetGeModel()->GetWeight(); + GELOGI("To copy weight to device. 
weight size = %zu", weight_buffer.GetSize()); + GE_CHK_RT_RET(rtMemcpy(model_params_.weight_base, model_params_.weight_size, weight_buffer.GetData(), + weight_buffer.GetSize(), RT_MEMCPY_HOST_TO_DEVICE)); } return SUCCESS; @@ -173,6 +178,11 @@ Status SingleOpModel::LoadAllNodes() { continue; } + if (op_type == CONSTANT || op_type == CONSTANTOP) { + has_weight_ = true; + continue; + } + if (op_type == NETOUTPUT) { netoutput_op_ = op_desc; continue; @@ -341,13 +351,19 @@ Status SingleOpModel::BuildKernelExTask(const domi::KernelExDef &kernel_def, Sin } Status SingleOpModel::BuildCpuKernelTask(const domi::KernelDef &kernel_def, OpTask **task) { + const auto &context = kernel_def.context(); + auto iter = op_list_.find(context.op_index()); + if (iter == op_list_.end()) { + GELOGE(INTERNAL_ERROR, "op desc not found. op index = %u", context.op_index()); + return INTERNAL_ERROR; + } std::unique_ptr aicpucc_task(new (std::nothrow) AiCpuCCTask()); if (aicpucc_task == nullptr) { GELOGE(MEMALLOC_FAILED, "create aicpu_CC op task failed"); return MEMALLOC_FAILED; } - auto builder = AiCpuCCTaskBuilder(kernel_def); + auto builder = AiCpuCCTaskBuilder(iter->second->GetOpDesc(), kernel_def); auto ret = builder.BuildTask(*aicpucc_task); if (ret != SUCCESS) { GELOGE(ret, "build aicpu_CC op task failed"); diff --git a/src/ge/single_op/single_op_model.h b/src/ge/single_op/single_op_model.h index caa958e5..8becf438 100644 --- a/src/ge/single_op/single_op_model.h +++ b/src/ge/single_op/single_op_model.h @@ -87,6 +87,7 @@ class SingleOpModel { std::vector output_sizes_; std::vector data_ops_; OpDescPtr netoutput_op_; + bool has_weight_ = false; }; } // namespace ge diff --git a/src/ge/single_op/stream_resource.cc b/src/ge/single_op/stream_resource.cc index 703b22b2..c2b93974 100644 --- a/src/ge/single_op/stream_resource.cc +++ b/src/ge/single_op/stream_resource.cc @@ -16,12 +16,14 @@ #include "single_op/stream_resource.h" -#include "common/ge_inner_error_codes.h" #include 
"framework/common/debug/ge_log.h" #include "framework/common/debug/log.h" #include "runtime/rt.h" +#include "single_op/single_op_model.h" namespace ge { +StreamResource::StreamResource(uintptr_t resource_id) : resource_id_(resource_id) {} + StreamResource::~StreamResource() { for (auto mem : memory_list_) { if (mem != nullptr) { @@ -38,15 +40,8 @@ StreamResource::~StreamResource() { } } -void StreamResource::CacheOperator(const void *key, std::unique_ptr &&single_op) { - op_map_[key] = std::move(single_op); -} - -void StreamResource::CacheDynamicOperator(const void *key, std::unique_ptr &&single_op) { - dynamic_op_map_[key] = std::move(single_op); -} - SingleOp *StreamResource::GetOperator(const void *key) { + std::lock_guard lk(mu_); auto it = op_map_.find(key); if (it == op_map_.end()) { return nullptr; @@ -56,6 +51,7 @@ SingleOp *StreamResource::GetOperator(const void *key) { } DynamicSingleOp *StreamResource::GetDynamicOperator(const void *key) { + std::lock_guard lk(mu_); auto it = dynamic_op_map_.find(key); if (it == dynamic_op_map_.end()) { return nullptr; @@ -73,20 +69,6 @@ uint8_t *StreamResource::DoMallocMemory(const std::string &purpose, size_t size, return allocated.back(); } - if (!allocated.empty()) { - GELOGD("Expand workspace memory size from %zu to %zu", max_allocated, size); - auto ret = rtStreamSynchronize(stream_); - if (ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "rtStreamSynchronize failed, ret = %d", ret); - return nullptr; - } - - auto addr = allocated.back(); - allocated.pop_back(); - (void)rtFree(addr); - max_allocated = 0; - } - uint8_t *buffer = nullptr; auto ret = rtMalloc(reinterpret_cast(&buffer), size, RT_MEMORY_HBM); if (ret != RT_ERROR_NONE) { @@ -117,7 +99,71 @@ uint8_t *StreamResource::MallocMemory(const std::string &purpose, size_t size) { uint8_t *StreamResource::MallocWeight(const std::string &purpose, size_t size) { GELOGD("To Malloc weight, size = %zu", size); - uint8_t *buffer = DoMallocMemory(purpose, size, 
max_weight_size_, weight_list_); + uint8_t *buffer = nullptr; + auto ret = rtMalloc(reinterpret_cast(&buffer), size, RT_MEMORY_HBM); + if (ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "rtMalloc failed, size = %zu, ret = %d", size, ret); + return nullptr; + } + + GE_PRINT_DYNAMIC_MEMORY(rtMalloc, purpose.c_str(), size) + weight_list_.emplace_back(buffer); return buffer; } + +Status StreamResource::BuildDynamicOperator(const string &model_name, const ModelData &model_data, + DynamicSingleOp **single_op) { + std::lock_guard lk(mu_); + auto it = dynamic_op_map_.find(model_data.model_data); + if (it != dynamic_op_map_.end()) { + *single_op = it->second.get(); + return SUCCESS; + } + + SingleOpModel model(model_name, model_data.model_data, model_data.model_len); + auto ret = model.Init(); + if (ret != SUCCESS) { + GELOGE(ret, "Init model failed. model = %s, ret = %u", model_name.c_str(), ret); + return ret; + } + + auto new_op = + std::unique_ptr(new (std::nothrow) DynamicSingleOp(resource_id_, &stream_mu_, stream_)); + GE_CHECK_NOTNULL(new_op); + + GELOGI("To build operator: %s", model_name.c_str()); + GE_CHK_STATUS_RET(model.BuildDynamicOp(*new_op), "Build op failed. op = %s, ret = %u", model_name.c_str(), ret); + *single_op = new_op.get(); + dynamic_op_map_[model_data.model_data] = std::move(new_op); + return SUCCESS; +} + +Status StreamResource::BuildOperator(const string &model_name, const ModelData &model_data, SingleOp **single_op) { + std::lock_guard lk(mu_); + auto it = op_map_.find(model_data.model_data); + if (it != op_map_.end()) { + *single_op = it->second.get(); + return SUCCESS; + } + + SingleOpModel model(model_name, model_data.model_data, model_data.model_len); + auto ret = model.Init(); + if (ret != SUCCESS) { + GELOGE(ret, "Init model failed. 
model = %s, ret = %u", model_name.c_str(), ret); + return ret; + } + + auto new_op = std::unique_ptr(new (std::nothrow) SingleOp(&stream_mu_, stream_)); + if (new_op == nullptr) { + GELOGE(MEMALLOC_FAILED, "new SingleOp failed"); + return MEMALLOC_FAILED; + } + + GELOGI("To build operator: %s", model_name.c_str()); + GE_CHK_STATUS_RET(model.BuildOp(*this, *new_op), "Build op failed. op = %s, ret = %u", model_name.c_str(), ret); + + *single_op = new_op.get(); + op_map_[model_data.model_data] = std::move(new_op); + return SUCCESS; +} } // namespace ge diff --git a/src/ge/single_op/stream_resource.h b/src/ge/single_op/stream_resource.h index 6f26c497..3c0dd03f 100644 --- a/src/ge/single_op/stream_resource.h +++ b/src/ge/single_op/stream_resource.h @@ -19,8 +19,9 @@ #include #include -#include +#include #include +#include #include "common/ge_inner_error_codes.h" #include "runtime/stream.h" @@ -29,21 +30,21 @@ namespace ge { class StreamResource { public: - StreamResource() = default; + explicit StreamResource(uintptr_t resource_id); ~StreamResource(); StreamResource(const StreamResource &) = delete; StreamResource(StreamResource &&) = delete; StreamResource &operator=(const StreamResource &) = delete; StreamResource &operator=(StreamResource &&) = delete; - - void CacheOperator(const void *key, std::unique_ptr &&single_op); - void CacheDynamicOperator(const void *key, std::unique_ptr &&single_op); void SetStream(rtStream_t stream); SingleOp *GetOperator(const void *key); DynamicSingleOp *GetDynamicOperator(const void *key); + Status BuildOperator(const std::string &model_name, const ModelData &model_data, SingleOp **single_op); + Status BuildDynamicOperator(const std::string &model_name, const ModelData &model_data, DynamicSingleOp **single_op); + uint8_t *MallocMemory(const std::string &purpose, size_t size); uint8_t *MallocWeight(const std::string &purpose, size_t size); @@ -51,13 +52,15 @@ class StreamResource { uint8_t *DoMallocMemory(const std::string &purpose, 
size_t size, size_t &max_allocated, std::vector &allocated); + uintptr_t resource_id_; size_t max_memory_size_ = 0; - size_t max_weight_size_ = 0; std::vector memory_list_; std::vector weight_list_; std::unordered_map> op_map_; std::unordered_map> dynamic_op_map_; rtStream_t stream_ = nullptr; + std::mutex mu_; + std::mutex stream_mu_; }; } // namespace ge diff --git a/src/ge/single_op/task/aicpu_kernel_task_builder.cc b/src/ge/single_op/task/aicpu_kernel_task_builder.cc index 936c7b67..4264f8c5 100644 --- a/src/ge/single_op/task/aicpu_kernel_task_builder.cc +++ b/src/ge/single_op/task/aicpu_kernel_task_builder.cc @@ -17,7 +17,8 @@ #include "single_op/task/aicpu_kernel_task_builder.h" namespace ge { -AiCpuCCTaskBuilder::AiCpuCCTaskBuilder(const domi::KernelDef &kernel_def) : kernel_def_(kernel_def) {} +AiCpuCCTaskBuilder::AiCpuCCTaskBuilder(const OpDescPtr &op_desc, const domi::KernelDef &kernel_def) + : op_desc_(op_desc), kernel_def_(kernel_def) {} Status AiCpuCCTaskBuilder::SetKernelArgs(AiCpuCCTask &task) { size_t aicpu_arg_size = kernel_def_.args_size(); @@ -25,20 +26,21 @@ Status AiCpuCCTaskBuilder::SetKernelArgs(AiCpuCCTask &task) { GELOGE(RT_FAILED, "aicpu_arg_size is invalid, value = %zu", aicpu_arg_size); return RT_FAILED; } - void *aicpu_args = malloc(aicpu_arg_size); + std::unique_ptr aicpu_args; + aicpu_args.reset(new (std::nothrow) uint8_t[aicpu_arg_size]()); if (aicpu_args == nullptr) { GELOGE(RT_FAILED, "malloc failed, size = %zu", aicpu_arg_size); return RT_FAILED; } - task.SetKernelArgs(aicpu_args, aicpu_arg_size); - auto err = memcpy_s(aicpu_args, aicpu_arg_size, kernel_def_.args().data(), aicpu_arg_size); + auto err = memcpy_s(aicpu_args.get(), aicpu_arg_size, kernel_def_.args().data(), aicpu_arg_size); if (err != EOK) { GELOGE(RT_FAILED, "memcpy_s args failed, size = %zu, err = %d", aicpu_arg_size, err); return RT_FAILED; } - task.SetIoAddr(static_cast(aicpu_args) + sizeof(aicpu::AicpuParamHead)); + task.SetIoAddr(aicpu_args.get() + 
sizeof(aicpu::AicpuParamHead)); + task.SetKernelArgs(std::move(aicpu_args), aicpu_arg_size); return SUCCESS; } @@ -51,6 +53,7 @@ Status AiCpuCCTaskBuilder::BuildTask(AiCpuCCTask &task) { const std::string &kernel_name = kernel_def_.kernel_name(); task.SetSoName(so_name); task.SetkernelName(kernel_name); + task.op_desc_ = op_desc_; return SUCCESS; } } // namespace ge \ No newline at end of file diff --git a/src/ge/single_op/task/aicpu_kernel_task_builder.h b/src/ge/single_op/task/aicpu_kernel_task_builder.h index c445132e..f9ca0530 100644 --- a/src/ge/single_op/task/aicpu_kernel_task_builder.h +++ b/src/ge/single_op/task/aicpu_kernel_task_builder.h @@ -18,6 +18,7 @@ #define GE_SINGLE_OP_TASK_AICPU_KERNEL_TASK_BUILDER_H_ #include +#include "graph/op_desc.h" #include "aicpu/common/aicpu_task_struct.h" #include "single_op/single_op.h" #include "single_op/single_op_model.h" @@ -26,13 +27,14 @@ namespace ge { class AiCpuCCTaskBuilder { public: - explicit AiCpuCCTaskBuilder(const domi::KernelDef &kernel_def); + explicit AiCpuCCTaskBuilder(const OpDescPtr &op_desc, const domi::KernelDef &kernel_def); ~AiCpuCCTaskBuilder() = default; Status BuildTask(AiCpuCCTask &task); private: Status SetKernelArgs(AiCpuCCTask &task); + const OpDescPtr op_desc_; const domi::KernelDef &kernel_def_; }; } // namespace ge diff --git a/src/ge/single_op/task/aicpu_task_builder.cc b/src/ge/single_op/task/aicpu_task_builder.cc index bc2c76f6..aba29f93 100644 --- a/src/ge/single_op/task/aicpu_task_builder.cc +++ b/src/ge/single_op/task/aicpu_task_builder.cc @@ -128,6 +128,7 @@ Status AiCpuTaskBuilder::BuildTask(ge::AiCpuTask &task, const SingleOpModelParam task.io_addr_ = io_addr; task.task_info_ = kernel_def_.task_info(); task.workspace_addr_ = ws_addr_vec[0]; + task.op_desc_ = op_desc_; auto debug_info = BuildTaskUtils::GetTaskInfo(op_desc_); GELOGI("[TASK_INFO] %s %s", task.task_info_.c_str(), debug_info.c_str()); diff --git a/src/ge/single_op/task/op_task.cc b/src/ge/single_op/task/op_task.cc 
index ddc4992c..8280fff5 100644 --- a/src/ge/single_op/task/op_task.cc +++ b/src/ge/single_op/task/op_task.cc @@ -16,13 +16,15 @@ #include "single_op/task/op_task.h" +#include #include #include -#include -#include "runtime/rt.h" -#include "register/op_tiling.h" +#include "common/dump/dump_manager.h" +#include "common/dump/dump_op.h" #include "framework/common/debug/log.h" +#include "register/op_tiling.h" +#include "runtime/rt.h" namespace ge { namespace { @@ -30,15 +32,44 @@ constexpr int kLaunchRetryTimes = 1000; constexpr int kSleepTime = 10; } // namespace +Status OpTask::OpenDump(void *arg, const OpDescPtr &op_desc, rtStream_t stream) { + if (DumpManager::GetInstance().IsDumpOpen()) { + GELOGI("Dump is open in single op,start to set dump info"); + std::vector input_addrs; + std::vector output_adds; + auto input_size = op_desc->GetAllInputsDesc().size(); + auto output_size = op_desc->GetOutputsSize(); + for (size_t i = 0; i < input_size; i++) { + uint64_t input_addr = *(reinterpret_cast(arg) + i); + input_addrs.emplace_back(input_addr); + } + for (size_t j = 0; j < output_size; j++) { + uint64_t output_addr = *(reinterpret_cast(arg) + input_size + j); + output_adds.emplace_back(output_addr); + } + dump_op_.SetDumpInfo(DumpManager::GetInstance().GetDumpProperties(), op_desc, input_addrs, output_adds, stream); + auto status = dump_op_.LaunchDumpOp(); + if (status != SUCCESS) { + GELOGE(status, "Launch dump op failed in single op"); + return status; + } + return SUCCESS; + } + GELOGI("Dump is not open in single op"); + return SUCCESS; +} + void TbeOpTask::SetStubFunc(const std::string &name, const void *stub_func) { this->stub_name_ = name; this->stub_func_ = stub_func; } -void TbeOpTask::SetKernelArgs(std::unique_ptr &&args, size_t arg_size, uint32_t block_dim) { +void TbeOpTask::SetKernelArgs(std::unique_ptr &&args, size_t arg_size, uint32_t block_dim, + const OpDescPtr &op_desc) { args_ = std::move(args); arg_size_ = arg_size; block_dim_ = block_dim; + op_desc_ 
= op_desc; } void TbeOpTask::SetSmDesc(void *sm_desc) { sm_desc_ = sm_desc; } @@ -79,8 +110,13 @@ Status TbeOpTask::LaunchKernel(rtStream_t stream) { GELOGE(RT_FAILED, "Invoke rtKernelLaunch failed. ret = %d, task = %s", ret, this->stub_name_.c_str()); return RT_FAILED; } - GELOGI("[TASK_INFO] %s", this->stub_name_.c_str()); + + auto status = OpenDump(args_.get(), op_desc_, stream); + if (status != SUCCESS) { + GELOGE(status, "Open dump failed in tbe single op %s", stub_name_.c_str()); + return status; + } return SUCCESS; } @@ -89,6 +125,7 @@ Status TbeOpTask::UpdateRunInfo(const vector &input_desc, const ve // invoke OpParaCalculate GELOGD("Start to invoke OpParaCalculate."); optiling::OpRunInfo run_info; + run_info.block_dim = 0; auto ret = optiling::OpParaCalculate(*node_, run_info); if (ret != GRAPH_SUCCESS) { GELOGE(FAILED, "Failed to invoke OpParaCalculate. ret = %u", ret); @@ -194,6 +231,7 @@ AiCpuTask::~AiCpuTask() { const void *AiCpuTask::GetIOAddr() const { return io_addr_; } Status AiCpuTask::LaunchKernel(rtStream_t stream) { + GELOGD("Start to launch kernel. task = %s", this->op_type_.c_str()); auto ret = rtMemcpyAsync(workspace_addr_, task_info_.size(), task_info_.data(), task_info_.size(), RT_MEMCPY_HOST_TO_DEVICE_EX, stream); if (ret != RT_ERROR_NONE) { @@ -201,20 +239,27 @@ Status AiCpuTask::LaunchKernel(rtStream_t stream) { return RT_FAILED; } - GELOGD("To invoke rtKernelLaunchEx. task = %s", this->op_type_.c_str()); + GELOGI("To invoke rtKernelLaunchEx. task = %s", this->op_type_.c_str()); ret = rtKernelLaunchEx(args_, arg_size_, 0, stream); if (ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Invoke rtKernelLaunch failed. 
ret = %d, task = %s", ret, this->op_type_.c_str()); return RT_FAILED; } - GELOGI("[TASK_INFO] %s", this->task_info_.c_str()); + GELOGI("[TASK_INFO] is %s", this->task_info_.c_str()); + + auto status = OpenDump(args_, op_desc_, stream); + if (status != SUCCESS) { + GELOGE(status, "Open dump failed in aicpu single op %s", op_type_.c_str()); + return status; + } + GELOGD("Done launch kernel successfully. task = %s", this->op_type_.c_str()); return SUCCESS; } -void AiCpuCCTask::SetKernelArgs(void *args, size_t arg_size) { - args_ = args; +void AiCpuCCTask::SetKernelArgs(std::unique_ptr args, size_t arg_size) { + args_ = std::move(args); arg_size_ = arg_size; - // the blockdim value is defult "1" for rtCpuKernelLaunch + // The blockdim value is defult "1" for rtCpuKernelLaunch block_dim_ = 1; } @@ -226,16 +271,11 @@ void AiCpuCCTask::SetIoAddr(void *io_addr) { io_addr_ = io_addr; } const void *AiCpuCCTask::GetIOAddr() const { return io_addr_; } -const void *AiCpuCCTask::GetArgs() const { return args_; } +const void *AiCpuCCTask::GetArgs() const { return args_.get(); } size_t AiCpuCCTask::GetArgSize() const { return arg_size_; } -AiCpuCCTask::~AiCpuCCTask() { - if (args_ != nullptr) { - free(args_); - args_ = nullptr; - } -} +AiCpuCCTask::~AiCpuCCTask() {} Status AiCpuCCTask::LaunchKernel(rtStream_t stream) { GELOGI("To invoke rtCpuKernelLaunch. block_dim = %u, so_name is %s, kernel_name is %s", block_dim_, so_name_.data(), @@ -244,12 +284,18 @@ Status AiCpuCCTask::LaunchKernel(rtStream_t stream) { auto *sm_desc = reinterpret_cast(sm_desc_); auto ret = rtCpuKernelLaunch(static_cast(so_name_.data()), static_cast(kernel_name_.data()), - block_dim_, args_, static_cast(arg_size_), sm_desc, stream); + block_dim_, args_.get(), static_cast(arg_size_), sm_desc, stream); if (ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Invoke rtCpuKernelLaunch failed. 
ret = %d", ret); return RT_FAILED; } GELOGD("Invoke rtCpuKernelLaunch succeeded"); + + auto status = OpenDump(args_.get(), op_desc_, stream); + if (status != SUCCESS) { + GELOGE(status, "Open dump failed in aicpucc single op"); + return status; + } return SUCCESS; } } // namespace ge diff --git a/src/ge/single_op/task/op_task.h b/src/ge/single_op/task/op_task.h index 3e261b3f..0401a177 100644 --- a/src/ge/single_op/task/op_task.h +++ b/src/ge/single_op/task/op_task.h @@ -21,9 +21,11 @@ #include #include -#include "runtime/stream.h" +#include "common/dump/dump_op.h" +#include "common/dump/dump_properties.h" #include "common/ge_inner_error_codes.h" #include "graph/op_kernel_bin.h" +#include "runtime/stream.h" #include "graph/node.h" namespace ge { @@ -47,12 +49,17 @@ class OpTask { return UNSUPPORTED; } virtual OpTaskType GetOpTaskType() = 0; - + virtual const void *GetIOAddr() const = 0; const vector &GetWorkspaceSizes() const; void SetWorkspaceSizes(const vector &workspace_sizes); private: std::vector workspace_sizes_; + + protected: + Status OpenDump(void *arg, const OpDescPtr &op_desc, rtStream_t stream); + DumpProperties dump_properties_; + DumpOp dump_op_; }; class TbeOpTask : public OpTask { @@ -60,10 +67,10 @@ class TbeOpTask : public OpTask { ~TbeOpTask() override; Status LaunchKernel(rtStream_t stream) override; OpTaskType GetOpTaskType() override { return OP_TASK_TBE; } - + const void *GetIOAddr() const override { return nullptr; } void SetSmDesc(void *sm_desc); void SetStubFunc(const std::string &name, const void *stub_func); - void SetKernelArgs(std::unique_ptr &&args, size_t arg_size, uint32_t block_dim); + void SetKernelArgs(std::unique_ptr &&args, size_t arg_size, uint32_t block_dim, const OpDescPtr &op_desc); Status UpdateRunInfo(const vector &input_desc, const vector &output_desc) override; @@ -90,6 +97,7 @@ class TbeOpTask : public OpTask { uint32_t max_tiling_size_ = 0; std::string tiling_data_; NodePtr node_; + OpDescPtr op_desc_; }; class 
AiCpuTask : public OpTask { @@ -99,7 +107,7 @@ class AiCpuTask : public OpTask { Status LaunchKernel(rtStream_t stream) override; OpTaskType GetOpTaskType() override { return OP_TASK_AICPU; } - const void *GetIOAddr() const; + const void *GetIOAddr() const override; private: friend class AiCpuTaskBuilder; @@ -109,6 +117,7 @@ class AiCpuTask : public OpTask { size_t arg_size_ = 0; std::string op_type_; void *io_addr_ = nullptr; + OpDescPtr op_desc_; }; class AiCpuCCTask : public OpTask { @@ -120,9 +129,9 @@ class AiCpuCCTask : public OpTask { Status LaunchKernel(rtStream_t stream) override; OpTaskType GetOpTaskType() override { return OP_TASK_AICPUCC; } - const void *GetIOAddr() const; + const void *GetIOAddr() const override; const void *GetArgs() const; - void SetKernelArgs(void *args, size_t arg_size); + void SetKernelArgs(std::unique_ptr args, size_t arg_size); void SetSoName(const std::string &so_name); void SetkernelName(const std::string &kernel_Name); void SetIoAddr(void *io_addr); @@ -132,11 +141,12 @@ class AiCpuCCTask : public OpTask { friend class AiCpuCCTaskBuilder; std::string so_name_; std::string kernel_name_; - void *args_ = nullptr; + std::unique_ptr args_; size_t arg_size_ = 0; uint32_t block_dim_ = 1; void *sm_desc_ = nullptr; void *io_addr_ = nullptr; + OpDescPtr op_desc_; }; } // namespace ge diff --git a/src/ge/single_op/task/tbe_task_builder.cc b/src/ge/single_op/task/tbe_task_builder.cc index 23c023fd..935c62a3 100644 --- a/src/ge/single_op/task/tbe_task_builder.cc +++ b/src/ge/single_op/task/tbe_task_builder.cc @@ -19,8 +19,8 @@ #include #include -#include "graph/load/new_model_manager/model_utils.h" #include "graph/debug/ge_attr_define.h" +#include "graph/load/new_model_manager/model_utils.h" #include "graph/manager/graph_var_manager.h" #include "runtime/rt.h" #include "single_op/task/build_task_utils.h" @@ -49,13 +49,6 @@ KernelHolder::~KernelHolder() { } } -KernelBinRegistry::~KernelBinRegistry() { - for (auto &iter : registered_bins_) { 
- delete iter.second; - iter.second = nullptr; - } -} - const char *KernelBinRegistry::GetUnique(const string &stub_func) { std::lock_guard lock(mutex_); auto it = unique_stubs_.find(stub_func); @@ -77,9 +70,9 @@ const char *KernelBinRegistry::GetStubFunc(const std::string &stub_name) { return nullptr; } -bool KernelBinRegistry::AddKernel(const std::string &stub_name, const KernelHolder *holder) { +bool KernelBinRegistry::AddKernel(const std::string &stub_name, std::unique_ptr &&holder) { std::lock_guard lock(mutex_); - auto ret = registered_bins_.emplace(stub_name, holder); + auto ret = registered_bins_.emplace(stub_name, std::move(holder)); return ret.second; } @@ -184,7 +177,7 @@ Status TbeTaskBuilder::RegisterKernel(TbeOpTask &task, const SingleOpModelParam return PARAM_INVALID; } - auto *holder = new (std::nothrow) KernelHolder(stub_func, tbe_kernel); + auto holder = std::unique_ptr(new (std::nothrow) KernelHolder(stub_func, tbe_kernel)); if (holder == nullptr) { GELOGE(MEMALLOC_FAILED, "create KernelHodler failed."); return MEMALLOC_FAILED; @@ -194,16 +187,11 @@ Status TbeTaskBuilder::RegisterKernel(TbeOpTask &task, const SingleOpModelParam auto ret = DoRegisterKernel(*tbe_kernel, stub_func, &bin_handle, param); if (ret == SUCCESS) { holder->SetBinHandle(bin_handle); - if (!registry.AddKernel(stub_name_, holder)) { + if (!registry.AddKernel(stub_name_, std::move(holder))) { // should not happen. only one thread can reach here - delete holder; - holder = nullptr; GELOGE(INTERNAL_ERROR, "Add kernel failed. 
stub name = %s", stub_name_.c_str()); return INTERNAL_ERROR; } - } else { - delete holder; - holder = nullptr; } } @@ -245,7 +233,7 @@ Status TbeTaskBuilder::GetSmDesc(void **sm_desc, const SingleOpModelParam ¶m return SUCCESS; } -Status TbeTaskBuilder::SetKernelArgs(TbeOpTask &task, const SingleOpModelParam ¶m) { +Status TbeTaskBuilder::SetKernelArgs(TbeOpTask &task, const SingleOpModelParam ¶m, const OpDescPtr &op_desc) { size_t arg_size = kernel_def_.args_size(); auto args = std::unique_ptr(new (std::nothrow) uint8_t[arg_size]); GE_CHECK_NOTNULL(args); @@ -276,13 +264,13 @@ Status TbeTaskBuilder::SetKernelArgs(TbeOpTask &task, const SingleOpModelParam & } } - task.SetKernelArgs(std::move(args), arg_size, kernel_def_.block_dim()); + task.SetKernelArgs(std::move(args), arg_size, kernel_def_.block_dim(), op_desc); return SUCCESS; } Status TbeTaskBuilder::BuildTask(TbeOpTask &task, const SingleOpModelParam ¶m) { GELOGD("Build tbe task begin"); - auto ret = SetKernelArgs(task, param); + auto ret = SetKernelArgs(task, param, op_desc_); if (ret != SUCCESS) { return ret; } diff --git a/src/ge/single_op/task/tbe_task_builder.h b/src/ge/single_op/task/tbe_task_builder.h index 7c5f8054..5cd5c463 100644 --- a/src/ge/single_op/task/tbe_task_builder.h +++ b/src/ge/single_op/task/tbe_task_builder.h @@ -44,8 +44,6 @@ class KernelHolder { class KernelBinRegistry { public: - ~KernelBinRegistry(); - static KernelBinRegistry &GetInstance() { static KernelBinRegistry instance; return instance; @@ -55,10 +53,10 @@ class KernelBinRegistry { const char *GetStubFunc(const std::string &stub_name); - bool AddKernel(const std::string &stub_name, const KernelHolder *holder); + bool AddKernel(const std::string &stub_name, std::unique_ptr &&holder); private: - std::map registered_bins_; + std::map> registered_bins_; std::set unique_stubs_; std::mutex mutex_; }; @@ -72,7 +70,7 @@ class TbeTaskBuilder { private: Status InitTilingInfo(TbeOpTask &task); - Status SetKernelArgs(TbeOpTask &task, 
const SingleOpModelParam ¶m); + Status SetKernelArgs(TbeOpTask &task, const SingleOpModelParam ¶m, const OpDescPtr &op_desc); Status GetSmDesc(void **sm_desc, const SingleOpModelParam ¶m) const; Status RegisterKernel(TbeOpTask &task, const SingleOpModelParam ¶m); diff --git a/src/proto/op_mapping_info.proto b/src/proto/op_mapping_info.proto index a02af28b..7b84a115 100644 --- a/src/proto/op_mapping_info.proto +++ b/src/proto/op_mapping_info.proto @@ -41,6 +41,16 @@ message Input { uint64 size = 5; } +enum BufferType { + L1 = 0; +} + +message OpBuffer { + BufferType buffer_type = 1; + uint64 address = 2; + uint64 size = 3; +} + message Op { string op_name = 1; string op_type = 2; @@ -53,6 +63,7 @@ message Task { repeated Output output = 4; bool end_graph = 5; repeated Input input = 6; + repeated OpBuffer buffer = 7; } message OpMappingInfo { diff --git a/third_party/fwkacllib/inc/cce/fwk_adpt_struct.h b/third_party/fwkacllib/inc/cce/fwk_adpt_struct.h index 023812dd..957117cc 100644 --- a/third_party/fwkacllib/inc/cce/fwk_adpt_struct.h +++ b/third_party/fwkacllib/inc/cce/fwk_adpt_struct.h @@ -57,6 +57,7 @@ enum FWKTaskExtInfoType { FWK_ADPT_EXT_INPUT_SHAPE, FWK_ADPT_EXT_OUTPUT_SHAPE, FWK_ADPT_EXT_UPDATE_ADDR, + FWK_ADPT_EXT_OP_NAME, FWK_ADPT_EXT_INVALID }; diff --git a/third_party/fwkacllib/inc/ops/aipp.h b/third_party/fwkacllib/inc/ops/aipp.h index 9dc5a018..85666223 100644 --- a/third_party/fwkacllib/inc/ops/aipp.h +++ b/third_party/fwkacllib/inc/ops/aipp.h @@ -44,13 +44,20 @@ REG_OP(Aipp) } // namespace ge /** -*@brief Performs This op is for dynamic aipp.If you set aipp-mode to dynamic in aipp config file, framework will auto add one input node to graph at last. +*@brief Performs this op is for dynamic aipp.If you set aipp-mode to dynamic \n +in aipp config file, framework will auto add one input node to graph at last. + +*@par Inputs: +*data: An NCHW or NHWC tensor of type uint8, specifying the input to the data layer. 
*@par Attributes: *index: specify aipp serial num *@par Outputs: -*features: The AIPP-processed output tensor of all types. +*out: The AIPP-processed output tensor of all types. + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator AippData. */ namespace ge { REG_OP(AippData) diff --git a/third_party/fwkacllib/inc/ops/control_flow_ops.h b/third_party/fwkacllib/inc/ops/control_flow_ops.h index 77980b67..fa68d49a 100644 --- a/third_party/fwkacllib/inc/ops/control_flow_ops.h +++ b/third_party/fwkacllib/inc/ops/control_flow_ops.h @@ -156,6 +156,19 @@ REG_OP(RefSwitch) DT_UINT64, DT_BOOL})) .OP_END_FACTORY_REG(RefSwitch) +/** + *@brief Forwards "data" to the output port determined by "pred_value". + + *@par Inputs: + *@li data: The tensor to be forwarded. \ n + * Must be one of the following types: float16, float32, float64, \n + * int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool. + *@li pred_value: A int64 tensor which determines the output port that will receive data. + + *@par Outputs: + *output: The output tensors, one of which will become available. \n + * Has the same type as "data". + */ REG_OP(SwitchN) .INPUT(data, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, @@ -166,7 +179,6 @@ REG_OP(SwitchN) DT_UINT64, DT_BOOL})) .OP_END_FACTORY_REG(SwitchN) - /** *@brief Creates or finds a child frame, and makes "x" available to the child \n * frame. This op is used together with Exit to create loops in the graph. \n diff --git a/third_party/fwkacllib/inc/ops/elewise_calculation_ops.h b/third_party/fwkacllib/inc/ops/elewise_calculation_ops.h index 1022880f..741a9071 100644 --- a/third_party/fwkacllib/inc/ops/elewise_calculation_ops.h +++ b/third_party/fwkacllib/inc/ops/elewise_calculation_ops.h @@ -2909,6 +2909,26 @@ REG_OP(Bias) .ATTR(bias_from_blob, Bool, true) .OP_END_FACTORY_REG(Bias) +/** +*@brief Function multiply gradients calculation. 
\n +output0 is the result of which input0 dot multily input1. +output1 is the result of which input0 dot multily input1, then reducesum it. + +*@par Inputs: +*@li input0: A Tensor of input of mul, and dtype supports float16, float32. +*@li input1: A Tensor of input of mul and mul_1, and dtype supports float16, float32. +*@li input2: A Tensor of input of mul_1, and dtype supports float16, float32'. + +*@par Attributes: +*@li axes: The dimensions to reduce. Default:(), reduce all dimensions. \n +Only constant value is allowed. +*@li keep_dims: If true, keep these reduced dimensions and the length is 1. \n +If false, don’t keep these dimensions. Default:False. + +*@par Outputs: +*@li output0: A Tensor result of which input0 dot multily input1. +*@li output1: A Tensor result of which input0 dot multily input1, then reducesum it. +*/ REG_OP(ConfusionMulGrad) .INPUT(input0, TensorType({DT_FLOAT16,DT_FLOAT})) .INPUT(input1, TensorType({DT_FLOAT16,DT_FLOAT})) @@ -2919,6 +2939,19 @@ REG_OP(ConfusionMulGrad) .ATTR(keep_dims, Bool, false) .OP_END_FACTORY_REG(ConfusionMulGrad) +/** +*@brief Function fused multiply l2 loss calculation. \n + +*@par Inputs: +*@li x1: A Tensor of type float16, float32. +*@li x2: A Tensor of type float16, float32. +*@li x3: A Tensor of type float16, float32. + +*@par Outputs: +*@li y1: A Tensor of shape and dtype of first output, which should have \n +shape (1,) and dtype as input. +*@li y2: A Tensor of shape and dtype of second output, should be same shape and type as input. +*/ REG_OP(FusedMulAddNL2loss) .INPUT(x1, TensorType::NumberType()) .INPUT(x2, TensorType::NumberType()) @@ -2927,7 +2960,6 @@ REG_OP(FusedMulAddNL2loss) .OUTPUT(y2, TensorType::NumberType()) .OP_END_FACTORY_REG(FusedMulAddNL2loss) - /** *@brief Tests whether the input exceeds a threshold. 
@@ -3056,6 +3088,23 @@ REG_OP(Fills) .OUTPUT(y, TensorType::NumberType()) /* "Result, has same element type as two inputs" */ .OP_END_FACTORY_REG(MulNoNan) +/** +*@brief Add tensor with scale. + +*@par Inputs: +*Five inputs, including: +* @li x1: A Tensor dtype of int32, float16, float32. +* @li x2: A Tensor dtype of int32, float16, float32. + +*@par Attributes: +*alpha: Float scalar apply to x2:x2*alpha + +*@par Outputs: +*y: A Tensor. should be same shape and type as "x1". + +*@par Third-party framework compatibility: +* Compatible with the Pytorch operator Axpy. +*/ REG_OP(Axpy) .INPUT(x1, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16})) .INPUT(x2, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16})) diff --git a/third_party/fwkacllib/inc/ops/functional_ops.h b/third_party/fwkacllib/inc/ops/functional_ops.h index a297bc61..f4a88661 100644 --- a/third_party/fwkacllib/inc/ops/functional_ops.h +++ b/third_party/fwkacllib/inc/ops/functional_ops.h @@ -34,6 +34,31 @@ REG_OP(RemoteCall) .GRAPH(f) .OP_END_FACTORY_REG(RemoteCall) +/** + *@brief Select one of the subgraphs to pass the input tensors and return the output tensors. \n + * If "cond" means True, the selected subgraph is "then_branch". \n + * Otherwise, the selected subgraph is "else_branch". + + *@par Inputs: + *@li cond: A Tensor. If "cond" is not a scalar of boolean type, \n + * it will be converted to a boolean according to the following rule: \n + * if "cond" is a numerical scalar, non-zero means True and zero means False; \n + * if "cond" is a string scalar, non-empty means True and empty means False; \n + * if "cond" is not a scalar, non-empty means True and empty means False. + *@li input: The input tensors. + + *@par Graphs: + *@li then_branch: A subgraph takes 'input' and returns a list of tensors, \n + * whose types are the same as what else_branch returns. + *@li else_branch: A subgraph takes 'input' and returns a list of tensors, \n + * whose types are the same as what then_branch returns. 
+ + *@par Outputs: + *output: The output tensors returned by either then_branch(input) or else_branch(input). + + *@par Third-party framework compatibility + *@Compatible with the TensorFlow operator _If. + */ REG_OP(_If) .INPUT(cond, TensorType::ALL()) .DYNAMIC_INPUT(input, TensorType::ALL()) @@ -42,6 +67,31 @@ REG_OP(_If) .GRAPH(else_branch) .OP_END_FACTORY_REG(_If) +/** + *@brief Select one of the subgraphs to pass the input tensors and return the output tensors. \n + * If "cond" means True, the selected subgraph is "then_branch". \n + * Otherwise, the selected subgraph is "else_branch". + + *@par Inputs: + *@li cond: A Tensor. If "cond" is not a scalar of boolean type, \n + * it will be converted to a boolean according to the following rule: \n + * if "cond" is a numerical scalar, non-zero means True and zero means False; \n + * if "cond" is a string scalar, non-empty means True and empty means False; \n + * if "cond" is not a scalar, non-empty means True and empty means False. + *@li input: The input tensors. + + *@par Graphs: + *@li then_branch: A subgraph takes 'input' and returns a list of tensors, \n + * whose types are the same as what else_branch returns. + *@li else_branch: A subgraph takes 'input' and returns a list of tensors, \n + * whose types are the same as what then_branch returns. + + *@par Outputs: + *output: The output tensors returned by either then_branch(input) or else_branch(input). + + *@par Third-party framework compatibility + *@Compatible with the TensorFlow operator StatelessIf. + */ REG_OP(StatelessIf) .INPUT(cond, TensorType::ALL()) .DYNAMIC_INPUT(input, TensorType::ALL()) @@ -50,6 +100,31 @@ REG_OP(StatelessIf) .GRAPH(else_branch) .OP_END_FACTORY_REG(StatelessIf) +/** + *@brief Select one of the subgraphs to pass the input tensors and return the output tensors. \n + * If "cond" means True, the selected subgraph is "then_branch". \n + * Otherwise, the selected subgraph is "else_branch". + + *@par Inputs: + *@li cond: A Tensor. 
If "cond" is not a scalar of boolean type, \n + * it will be converted to a boolean according to the following rule: \n + * if "cond" is a numerical scalar, non-zero means True and zero means False; \n + * if "cond" is a string scalar, non-empty means True and empty means False; \n + * if "cond" is not a scalar, non-empty means True and empty means False. + *@li input: The input tensors. + + *@par Graphs: + *@li then_branch: A subgraph takes 'input' and returns a list of tensors, \n + * whose types are the same as what else_branch returns. + *@li else_branch: A subgraph takes 'input' and returns a list of tensors, \n + * whose types are the same as what then_branch returns. + + *@par Outputs: + *output: The output tensors returned by either then_branch(input) or else_branch(input). + + *@par Third-party framework compatibility + *@Compatible with the TensorFlow operator If. + */ REG_OP(If) .INPUT(cond, TensorType::ALL()) .DYNAMIC_INPUT(input, TensorType::ALL()) @@ -58,6 +133,23 @@ REG_OP(If) .GRAPH(else_branch) .OP_END_FACTORY_REG(If) +/** + *@brief Select one of the subgraphs to pass the input tensors and return the output tensors. + + *@par Inputs: + *@li branch_index: A int32 scalar which determines the selected subgraph. + *@li input: The input tensors, which will be passed to the subgraph. + + *@par Graphs: + *branches: A list of subgraphs, each of which takes 'input' and returns a list of tensors, \n + * whose types are the same as what every other subgraph returns. + + *@par Outputs: + *output: The output tensors returned by one of branches. + + *@par Third-party framework compatibility + *@Compatible with the TensorFlow operator Case. + */ REG_OP(Case) .INPUT(branch_index, DT_INT32) .DYNAMIC_INPUT(input, TensorType::ALL()) @@ -65,6 +157,30 @@ REG_OP(Case) .DYNAMIC_GRAPH(branches) .OP_END_FACTORY_REG(Case) +/** + *@brief Cyclic execute the "body" subgraph until the return tensor of "cond" subgraph means False. + + *@par Inputs: + *input: The input tensors. 
+ + *@par Graphs: + *@li cond: A subgraph takes 'input' and returns a tensor. \n + * If the tensor is not a scalar of boolean type, \n + * it will be converted to a boolean according to the following rule: \n + * if it is a numerical scalar, non-zero means True and zero means False; \n + * if it is a string scalar, non-empty means True and empty means False; \n + * if it is not a scalar, non-empty means True and empty means False. + *@li body: A subgraph takes 'input' and returns a another list of tensors. + + *@par Attributes: + *parallel_iterations: An optional int, default as 10. + + *@par Outputs: + *output: The output tensors returned by "body". Has the same type as "input". + + *@par Third-party framework compatibility + *@Compatible with the TensorFlow operator _While. + */ REG_OP(_While) .DYNAMIC_INPUT(input, TensorType::ALL()) .DYNAMIC_OUTPUT(output, TensorType::ALL()) @@ -72,6 +188,30 @@ REG_OP(_While) .GRAPH(body) .OP_END_FACTORY_REG(_While) +/** + *@brief Cyclic execute the "body" subgraph until the return tensor of "cond" subgraph means False. + + *@par Inputs: + *input: The input tensors. + + *@par Graphs: + *@li cond: A subgraph takes 'input' and returns a tensor. \n + * If the tensor is not a scalar of boolean type, \n + * it will be converted to a boolean according to the following rule: \n + * if it is a numerical scalar, non-zero means True and zero means False; \n + * if it is a string scalar, non-empty means True and empty means False; \n + * if it is not a scalar, non-empty means True and empty means False. + *@li body: A subgraph takes 'input' and returns a another list of tensors. + + *@par Attributes: + *parallel_iterations: An optional int, default as 10. + + *@par Outputs: + *output: The output tensors returned by "body". Has the same type as "input". + + *@par Third-party framework compatibility + *@Compatible with the TensorFlow operator While. 
+ */ REG_OP(While) .DYNAMIC_INPUT(input, TensorType::ALL()) .DYNAMIC_OUTPUT(output, TensorType::ALL()) @@ -80,6 +220,30 @@ REG_OP(While) .ATTR(parallel_iterations, Int, 10) .OP_END_FACTORY_REG(While) +/** + *@brief Cyclic execute the "body" subgraph until the return tensor of "cond" subgraph means False. + + *@par Inputs: + *input: The input tensors. + + *@par Graphs: + *@li cond: A subgraph takes 'input' and returns a tensor. \n + * If the tensor is not a scalar of boolean type, \n + * it will be converted to a boolean according to the following rule: \n + * if it is a numerical scalar, non-zero means True and zero means False; \n + * if it is a string scalar, non-empty means True and empty means False; \n + * if it is not a scalar, non-empty means True and empty means False. + *@li body: A subgraph takes 'input' and returns a another list of tensors. + + *@par Attributes: + *parallel_iterations: An optional int, default as 10. + + *@par Outputs: + *output: The output tensors returned by "body". Has the same type as "input". + + *@par Third-party framework compatibility + *@Compatible with the TensorFlow operator StatelessWhile. + */ REG_OP(StatelessWhile) .DYNAMIC_INPUT(input, TensorType::ALL()) .DYNAMIC_OUTPUT(output, TensorType::ALL()) @@ -88,6 +252,24 @@ REG_OP(StatelessWhile) .ATTR(parallel_iterations, Int, 10) .OP_END_FACTORY_REG(StatelessWhile) +/** + *@brief Cyclic execute the "body" subgraph until the first input of For op exceed upper bound. + + *@par Inputs: + *@li start: A int32 scalar. The lower bound. + *@li limit: A int32 scalar. The upper bound. + *@li delta: A int32 scalar. The step size. + *@li input: The input tensors, which will be passed to "body". + + *@par Graphs: + *body: A subgraph takes 'input' and returns a another list of tensors. + + *@par Outputs: + *output: The output tensors returned by "body". Has the same type as "input". + + *@par Third-party framework compatibility + *@Compatible with the TensorFlow operator For. 
+ */
 REG_OP(For)
     .INPUT(start, DT_INT32)
     .INPUT(limit, DT_INT32)
@@ -97,6 +279,26 @@ REG_OP(For)
     .GRAPH(body)
     .OP_END_FACTORY_REG(For)
 
+/**
+ *@brief Pass the input tensors to the subgraph "f" and return the output tensors.
+
+ *@par Inputs:
+ *args: The input tensors, which will be passed to "f".
+
+ *@par Graphs:
+ *f: A subgraph takes 'args' and returns another list of tensors.
+
+ *@par Attributes:
+ *@li config: An optional string, default as "".
+ *@li config_proto: An optional string, default as "".
+ *@li executor_type: An optional string, default as "".
+
+ *@par Outputs:
+ *output: The output tensors returned by "f".
+
+ *@par Third-party framework compatibility
+ *@Compatible with the TensorFlow operator PartitionedCall.
+ */
 REG_OP(PartitionedCall)
     .DYNAMIC_INPUT(args, TensorType::ALL())
     .DYNAMIC_OUTPUT(output, TensorType::ALL())
@@ -106,6 +308,26 @@ REG_OP(PartitionedCall)
     .ATTR(executor_type, String, "")
     .OP_END_FACTORY_REG(PartitionedCall)
 
+/**
+ *@brief Pass the input tensors to the subgraph "f" and return the output tensors.
+
+ *@par Inputs:
+ *args: The input tensors, which will be passed to "f".
+
+ *@par Graphs:
+ *f: A subgraph takes 'args' and returns another list of tensors.
+
+ *@par Attributes:
+ *@li config: An optional string, default as "".
+ *@li config_proto: An optional string, default as "".
+ *@li executor_type: An optional string, default as "".
+
+ *@par Outputs:
+ *output: The output tensors returned by "f".
+
+ *@par Third-party framework compatibility
+ *@Compatible with the TensorFlow operator StatefulPartitionedCall.
+ */ REG_OP(StatefulPartitionedCall) .DYNAMIC_INPUT(args, TensorType::ALL()) .DYNAMIC_OUTPUT(output, TensorType::ALL()) diff --git a/third_party/fwkacllib/inc/ops/image_ops.h b/third_party/fwkacllib/inc/ops/image_ops.h index 59b99841..1ea62fa9 100644 --- a/third_party/fwkacllib/inc/ops/image_ops.h +++ b/third_party/fwkacllib/inc/ops/image_ops.h @@ -143,6 +143,45 @@ REG_OP(CropAndResize) .ATTR(method, String, "bilinear") .OP_END_FACTORY_REG(CropAndResize) +/** +*@brief Extracts crops from the input image tensor and resizes them. Extracts \n +crops from the input image tensor and resizes them using bilinear sampling or \n +nearest neighbor sampling to a common output size specified by crop_size. + +*@par Inputs: +*Input images must be a 5HD tensor. Inputs include: \n +*@li images:A Tensor. Must be one of the following types:float. A 5HD tensor of shape \n +[batch, C1, image_height, image_width, C0]. +*@li boxes: A Tensor of type float. A 2-D tensor of shape [num_boxes, 4]. +*@li box_index: A Tensor of type int32. A 1-D tensor of shape [num_boxes] with \n +int32 values in [0, batch - 1). + +*@par Attributes: +*@li crop_size: list int. [crop_height, crop_width]. All cropped image patches are resized to this size. +*@li extrapolation_value: An optional float. Defaults to 0. Value used for \n +extrapolation, when applicable. +*@li method: An optional string from: '"bilinear"'. Defaults to \n +"bilinear". + +*@par Outputs: +*y:A Tensor of type float. + +*@attention Constraints: \n +*Input images must be a 5HD tensor. + +*@par Third-party framework compatibility +*Compatible with tensorflow CropAndResize operator. 
+*/ +REG_OP(CropAndResizeD) + .INPUT(x, TensorType({DT_FLOAT})) + .INPUT(boxes, TensorType({DT_FLOAT})) + .INPUT(box_index, TensorType({DT_INT32})) + .OUTPUT(y, TensorType({DT_FLOAT})) + .REQUIRED_ATTR(crop_size, ListInt) + .ATTR(extrapolation_value, Float, 0) + .ATTR(method, String, "bilinear") + .OP_END_FACTORY_REG(CropAndResizeD) + /** *@brief Computes the gradient of the crop_and_resize op wrt the input \n boxes tensor. @@ -1257,6 +1296,22 @@ REG_OP(CombinedNonMaxSuppression) .ATTR(clip_boxes, Bool, true) .OP_END_FACTORY_REG(CombinedNonMaxSuppression) +/** +*@brief Function spatial transformer. + +*@par Inputs: +*@li x: A Tensor dtype of float16, float32. +*@li theta: A Tensor dtype of float16, float32, auxiliary coefficients. + +*@par Attributes: +*@li output_size: A tuple output size. +*@li default_theta: A tuple default theta +*@li use_default_theta: List use default theta +*@li align_corners: Align corners + +*@par Outputs: +*y: A Tensor dtype of float16, float32, should be same shape and type as x. +*/ REG_OP(SpatialTransformerD) .INPUT(x, TensorType({DT_FLOAT,DT_FLOAT16})) .OPTIONAL_INPUT(theta, TensorType({DT_FLOAT,DT_FLOAT16})) diff --git a/third_party/fwkacllib/inc/ops/internal_ops.h b/third_party/fwkacllib/inc/ops/internal_ops.h index 8c261382..0f9fd12f 100644 --- a/third_party/fwkacllib/inc/ops/internal_ops.h +++ b/third_party/fwkacllib/inc/ops/internal_ops.h @@ -56,6 +56,22 @@ REG_OP(CacheUpdate) .OUTPUT(x, TensorType::BasicType()) .OP_END_FACTORY_REG(CacheUpdate) +/** +*@brief transfer data from L1 buffer to DDR or DDR to L1. + +*@par Inputs: +*The input is dynamic for attribute func_name \n + +*@par Outputs: +*The output is dynamic for attribute func_name. 
+*/ +REG_OP(InternalDataMove) + .INPUT(x, TensorType::ALL()) + .OUTPUT(y, TensorType::ALL()) + .REQUIRED_ATTR(src_buf, String) + .REQUIRED_ATTR(dst_buf, String) + .OP_END_FACTORY_REG(InternalDataMove) + } // namespace ge #endif // GE_OP_INTERNAL_OPS_H_ diff --git a/third_party/fwkacllib/inc/ops/math_ops.h b/third_party/fwkacllib/inc/ops/math_ops.h index b0c35c28..6d1e2cd2 100644 --- a/third_party/fwkacllib/inc/ops/math_ops.h +++ b/third_party/fwkacllib/inc/ops/math_ops.h @@ -575,25 +575,24 @@ REG_OP(Conj) .OP_END_FACTORY_REG(Conj) /** - * *@brief The negative log likelihood loss. - * - * *@par Inputs: - * *The input x and weight must have the same type. Inputs include: \n - * *@li x:A Tensor. Must be the type: float32. - * *@li target:A Tensor. Must be the type: int32. - * *@li weight:A Tensor. Must be the type: float32. - * - * *@par Attributes: - * *@li reduction: An optional attribute. Defaults to "mean". - * - * *@par Outputs: - * *Two outputs, including: - * *@li y: A Tensor. Must be the following type: float32. - * *@li total_weight: A Tensor. Must be the type: float32. - * - * *@par Third-party framework compatibility - * *Compatible with pytorch NLLLoss operator - * */ +*@brief The negative log likelihood loss. + +*@par Inputs: +*The input x and weight must have the same type. Inputs include: \n +*@li x: A Tensor dtype of float32. +*@li target: A Tensor dtype of int32. +*@li weight: A Tensor dtype of float32. + +*@par Attributes: +*reduction: An optional attribute. Defaults to "mean". + +*@par Outputs: +*@li y: A Tensor dtype of float32. +*@li total_weight: A Tensor dtype of float32. + +*@par Third-party framework compatibility +*Compatible with pytorch NLLLoss operator +*/ REG_OP(NLLLoss) .INPUT(x, TensorType({DT_FLOAT})) .INPUT(target, TensorType({DT_INT32})) @@ -604,26 +603,24 @@ REG_OP(NLLLoss) .OP_END_FACTORY_REG(NLLLoss) /** - * *@brief The negative log likelihood loss grad. +*@brief The negative log likelihood loss grad. 
- * *@par Inputs: - * *Inputs include: - * *@li x:A Tensor. Must be the type: float32. - * *@li y_grad:A Tensor. Must be the type: float32. - * *@li target:A Tensor. Must be the type: int32. - * *@li weight:A Tensor. Must be the type: float32. - * *@li total_weight:A Tensor. Must be the type: float32. - * - * *@par Attributes: - * *@li reduction: An optional attribute. Defaults to "mean". - * - * *@par Outputs: - * *One outputs, including: - * *@li x_grad: A Tensor. Must be the following type: float32. - * - * *@par Third-party framework compatibility - * *Compatible with pytorch NLLLossGrad operator - * */ +*@par Inputs: +*@li x:A Tensor dtype of float32. +*@li y_grad:A Tensor dtype of float32. +*@li target:A Tensor dtype of int32. +*@li weight:A Tensor dtype of float32. +*@li total_weight:A Tensor dtype of float32. + +*@par Attributes: +*reduction: An optional attribute. Defaults to "mean". + +*@par Outputs: +*x_grad: A Tensor. Must be the following type: float32. + +*@par Third-party framework compatibility +*Compatible with pytorch NLLLossGrad operator +*/ REG_OP(NLLLossGrad) .INPUT(x, TensorType({DT_FLOAT})) .INPUT(y_grad, TensorType({DT_FLOAT})) diff --git a/third_party/fwkacllib/inc/ops/nn_calculation_ops.h b/third_party/fwkacllib/inc/ops/nn_calculation_ops.h index e9180332..b2cf56ad 100644 --- a/third_party/fwkacllib/inc/ops/nn_calculation_ops.h +++ b/third_party/fwkacllib/inc/ops/nn_calculation_ops.h @@ -410,20 +410,21 @@ REG_OP(Conv2DBackpropInputD) *@brief Computes the Deconvolution with respect to the input. *@par Inputs: * Three inputs: - * @li x: A Tensor. Must have the same type as "filter". 4D with shape + * @li x: A Tensor of type float16 or int8. 4D with shape * [batch, out_channels, out_height, out_width]. Gradients with respect * to the output of the convolution. - * @li filter: A Tensor of type float16, float32, double or int8. + * @li filter: A Tensor. Must have the same type as "x". 
* 4D with shape [out_channels, in_channel, filter_height, filter_width].\n * Two optional inputs: - * @li bias: An optional tensor of type float16, float32, int32 or int64. - * @li offset_w: An optional 1D tensor for quantized deconvolution. Type is int8. Reserved.\n + * @li bias: An optional tensor. Must have the same type as "y". + * @li offset_w: An optional 1D tensor for quantized deconvolution. + * Type is int8. Reserved.\n *@par Attributes: * Six attributes: * @li strides: A tuple or list of 2 integers. The stride of the sliding window - * for H/W dimension. Defaults to [1, 1, 1, 1]. + * for H/W dimension. * @li pads: A tuple or list of 4 integers. The [top, bottom, left, right] - * padding on the feature map. Defaults to [0, 0, 0, 0]. + * padding on the feature map. * @li dilations: A tuple or list of 4 integers. The dilation factor for each * dimension of input. Must be [1, 1, 1, 1]. * @li groups: Number of blocked connections from input channels to @@ -432,17 +433,18 @@ REG_OP(Conv2DBackpropInputD) Specify the data format of the input and output data. * @li offset_x: An optional integer for quantized deconvolution. Defaults to "0". *@par Outputs: - * y: A Tensor. Has the same type as "filter". 4D tensor with shape - * [batch, channels, height, width]. + * y: A Tensor. 4D tensor with shape [batch, channels, height, width]. + * When type of x is float16, the type of y must be float16. + * When type of x is int8, the type of y must be int32. 
*/ REG_OP(Deconvolution) - .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT8})) - .INPUT(filter, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT8})) - .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32})) + .INPUT(x, TensorType({DT_FLOAT16, DT_INT8})) + .INPUT(filter, TensorType({DT_FLOAT16, DT_INT8})) + .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_INT32})) .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8})) - .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32})) - .ATTR(strides, ListInt, {1, 1}) - .ATTR(pads, ListInt, {0, 0, 0, 0}) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_INT32})) + .REQUIRED_ATTR(strides, ListInt) + .REQUIRED_ATTR(pads, ListInt) .ATTR(dilations, ListInt, {1, 1, 1, 1}) .ATTR(groups, Int, 1) .ATTR(data_format, String, "NCHW") @@ -502,7 +504,7 @@ REG_OP(Conv2DBackpropFilter) * @li groups: Number of blocked connections from input channels to output channels. * @li data_format: An optional string from: "NHWC", "NCHW". Defaults to "NHWC". Specify the data format of the input and output data. *@par Outputs: - * y: A Tensor. Has the same type as x + * y: A Tensor. Type is float32 *@par Third-party framework compatibility * Compatible with Tensorflow's conv2d_backprop_filter */ @@ -525,7 +527,7 @@ REG_OP(Conv2DBackpropFilterD) * @li filter: A 4D tensor of filters. * @li bias: An optional 1D tensor. * @li offset_w: An optional 1D tensor for quantized convolution. Reserved. -* +* * The input and output tensor attributes are listed as follows: * @verbatim |Tensor | x | filter | bias | offset_w | y @@ -636,7 +638,8 @@ REG_OP(Conv2DCompress) /** *@brief Computes a 3D convolution given 5D "x" and "filter" tensors. *@par Inputs: - * @li x: A 5D tensor. Must be one of the following types: float16, float32, float64. The format is NCDHW or NDHWC. + * @li x: A 5D tensor. Must be one of the following types: float16, (Currently does not support int8). + * The format of x is NCDHW or NDHWC. 
* @li filter: A 5D tensor of the same type as "x". The format is NCDHW, NDHWC or DHWCN. *@par Optional input: @@ -644,13 +647,15 @@ REG_OP(Conv2DCompress) * @li offset_w: An optional 1D tensor for quantized deconvolution. Reserved. *@par Required Attributes: -* @li strides: A list of 5 ints. Specifies the stride of the sliding window for each dimension of "x". The N and C dimensions must be 1. Has the same format as "x". -* @li pads: A list of 6 ints. Supports only padding along the D, H and W dimensions in sequence of head, tail, top, bottom, left and right. + * @li strides: A list of 5 integers. Specifies the stride of the sliding window for each dimension of "x". + * The N and C dimensions must be 1. Has the same format as "x". + * @li pads: A list of 6 integers. Supports only padding along the D, H and W dimensions in sequence of head, tail, top, bottom, left and right. *@par Attributes: * @li groups: Number of blocked connections from input channels to output channels. * @li data_format: An optional string from: "NDHWC", "NCDHW". Defaults to "NDHWC". Specify the data format of the input and output data. - * @li dilations: A list of 5 ints. Specifies the dilation factor for each dimension of "x". The N and C dimensions must be 1. Has the same format as "x". + * @li dilations: A list of 5 integers. Specifies the dilation factor for each dimension of "x". + * The N and C dimensions must be 1. Has the same format as "x". * @li offset_x: An optional int. Input offset, used for quantized inference. Defaults to 0. *@par Outputs: @@ -664,11 +669,11 @@ REG_OP(Conv2DCompress) * @li Compatible with the Caffe operator Convolution. 
*/ REG_OP(Conv3D) - .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) - .INPUT(filter, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) - .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .INPUT(x, TensorType({DT_FLOAT16})) + .INPUT(filter, TensorType({DT_FLOAT16})) + .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16})) .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8})) - .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_FLOAT16})) .REQUIRED_ATTR(strides, ListInt) .REQUIRED_ATTR(pads, ListInt) .ATTR(dilations, ListInt, {1, 1, 1, 1, 1}) @@ -677,6 +682,7 @@ REG_OP(Conv3D) .ATTR(offset_x, Int, 0) .OP_END_FACTORY_REG(Conv3D) + /** *@brief Computes the gradients of convolution 3d with respect to the input. *@par Inputs: @@ -688,14 +694,15 @@ REG_OP(Conv3D) * or [batch, out_channels, depth, out_height, out_width]. Gradients with respect to the output of the convolution. *@par Required Attributes: - * @li strides: A list of 5 ints. Specifies the stride of the sliding window for each dimension of "x". The N and C dimensions must be 1. Has the same format as "x". - * @li pads: A list of 6 ints. Supports only padding along the D, H and W dimensions in sequence of head, tail, top, bottom, left and right. + * @li strides: A list of 5 integers. Specifies the stride of the sliding window for each dimension of "x". + * The N and C dimensions must be 1. Has the same format as "x". + * @li pads: A list of 6 integers. Supports only padding along the D, H and W dimensions in sequence of head, tail, top, bottom, left and right. *@par Attributes: * Three attributes: * @li groups: Number of blocked connections from input channels to output channels. - * @li data_format: An optional string from: "NDHWC", "NCHWD". Defaults to "NDHWC". Specify the data format of the input and output data. 
- * @li dilations: A tuple/list of 6 integers, The dilation factor for each dimension of input, now only support [1,1,1,1,1] + * @li data_format: An optional string from: "NDHWC", "NCDHW". Defaults to "NDHWC". Specify the data format of the input and output data. + * @li dilations: A tuple/list of 5 integers, The dilation factor for each dimension of the input, now only support [1,1,1,1,1] *@par Outputs: * y: A Tensor. Has the same type as filter,and has same format as input_size @@ -719,29 +726,27 @@ REG_OP(Conv3DBackpropInput) *@brief Computes the gradients of convolution 3d with respect to the input. *@par Inputs: * Two inputs: - * @li filter: A Tensor. Types is float16. + * @li filter: A Tensor whose type is float16. * @li out_backprop: A Tensor. Must have the same type as filter. *@par Required Attributes: - *@li strides: A list of 5 ints. Specifies the stride of the sliding window for - each dimension of "x". The N and C dimensions must be 1. Has the same format as "x". - *@li pads: A list of 6 ints. Supports only padding along the D, H and W - dimensions in sequence of head, tail, top, bottom, left and right. - *@li input_size: A Tensor of type int32, int64. An integer vector representing the shape of input, + * @li strides: A list of 5 integers. Specifies the stride of the sliding window for + * each dimension of "x". The N and C dimensions must be 1. Has the same format as "x". + * @li pads: A list of 6 integers. Supports only padding along the D, H and W + * dimensions in sequence of head, tail, top, bottom, left and right. + * @li input_size: A tuple/list of type int32, int64. An integer vector representing the shape of input, * where input is a 5-D tensor [batch, depth, height, width, channels] or [batch, channels, depth, height, width]. *@par Attributes: * Three attributes: * @li groups: Number of blocked connections from input channels to output channels. - * @li data_format: An optional string from: "NDHWC", "NCHWD". Defaults to "NDHWC". 
Specify the data format of the input and output data.
+ * @li data_format: An optional string from: "NDHWC", "NCDHW". Defaults to "NDHWC". Specify the data format of the input and output data.
 * @li dilations: A tuple/list of 5 integers, The dilation factor for each dimension of input, now only support [1,1,1,1,1]
 
 *@par Outputs:
 * y: A Tensor. Has the same type as filter
 *@par Third-party framework compatibility
 * Compatible with Tensorflow's conv3d_backprop_input
 */
-
-
 REG_OP(Conv3DBackpropInputD)
     .INPUT(filter, TensorType({DT_FLOAT16}))
     .INPUT(out_backprop, TensorType({DT_FLOAT16}))
@@ -754,6 +759,32 @@ REG_OP(Conv3DBackpropInputD)
     .ATTR(data_format, String, "NDHWC")
     .OP_END_FACTORY_REG(Conv3DBackpropInputD)
 
+/**
+*@brief Applies a multi-layer long short-term memory (LSTM) RNN to an input sequence.
+
+*@par Inputs:
+* @li x: A Tensor dtype of float16.
+* @li cont: A Tensor dtype of float16, float32.
+* @li w_x: A Tensor dtype of float16.
+* @li bias: A Tensor dtype of int16, int32, float16, float32.
+* @li w_h: A Tensor dtype of float16.
+* @li x_static: An optional Tensor dtype of float16.
+* @li h_0: An optional Tensor dtype of float16, float32.
+* @li c_0: An optional Tensor dtype of float16, float32.
+* @li w_x_static: An optional Tensor dtype of float16.
+
+*@par Attributes:
+*@li num_output: A Scalar of output size dtype of int.
+*@li expose_hidden: A Scalar(bool) of features hidden.
+
+*@par Outputs:
+*@li h: A Tensor dtype of float16, float32.
+* @li h_t: An optional Tensor dtype of float16, float32. The hidden state at time t.
+* @li c_t: An optional Tensor dtype of float16, float32. The cell state at time t.
+
+*@par Third-party framework compatibility:
+* Compatible with the Pytorch operator adds.
+*/
 REG_OP(LSTM)
     .INPUT(x, TensorType({DT_FLOAT16}))
     .INPUT(cont, TensorType({DT_FLOAT32,DT_FLOAT16}))
@@ -775,20 +806,25 @@ REG_OP(LSTM)
 *@brief Computes the gradients of convolution3D with respect to the filter
 *@par Inputs:
 * Three inputs:
- * @li x: A Tensor.
Must be one of the following types: float16 + * @li x: A Tensor. Must be one of the following types: float16, float32, double. * 5-D with shape [batch, in_depth, in_height, in_width, in_channels] or [batch, in_depth, in_channels, in_height, in_width]. * @li filter_size: A Tensor of type int32. An integer vector representing the tensor shape of filter, * where filter is a 5-D tensor [filter_depth, filter_height, filter_width, in_channels, out_channels] * or [out_channels, filter_depth, filter_height, filter_width, in_channels] or [out_channels, filter_depth, in_channel, filter_height, filter_width]. * @li out_backprop: A Tensor. Must have the same type as x. 5-D with shape [batch, out_depth, out_height, out_width, out_channels] * or [batch, out_depth, out_channels, out_height, out_width]. Gradients with respect to the output of the convolution. + +*@par Required Attributes: + * @li strides: A tuple/list of 5 integers. Specifies the stride of the sliding window for + * each dimension of "x". The N and C dimensions must be 1. Has the same format as "x". + * @li pads: A tuple/list of 6 integers, [front, back, top, bottom, left, right] pads on feature map. + *@par Attributes: * Three attributes: - * @li strides: A tuple/list of 3 integers. The stride of the sliding window for D/H/W dimension. - * @li pads: A tuple/list of 6 integers, [front, back, top, bottom, left, right] pads on feature map. * @li dilations: A tuple/list of 5 integers, The dilation factor for each dimension of input, now only support [1,1,1,1,1]. * @li groups: Number of blocked connections from input channels to output channels. - * @li data_format: An optional string from: "NDHWC", "NDCHW". Defaults to "NDHWC". Specify the data format of the input and output data. + * @li data_format: An optional string from: "NDHWC", "NCDHW". Defaults to "NDHWC". Specify the data format of the input and output data. + *@par Outputs: * y: A Tensor. 
Has the same type as x *@par Third-party framework compatibility @@ -810,20 +846,25 @@ REG_OP(Conv3DBackpropFilter) *@brief Computes the gradients of convolution with respect to the filter. *@par Inputs: * Two inputs: - * @li x: A Tensor. Type is float16. + * @li x: A Tensor of type float16. * 5-D with shape [batch, in_depth, in_height, in_width, in_channels] or [batch, in_depth, in_channels, in_height, in_width]. * @li out_backprop: A Tensor. Must have the same type as x. 5-D with shape [batch, out_depth, out_height, out_width, out_channels] * or [batch, out_depth, out_channels, out_height, out_width]. Gradients with respect to the output of the convolution. -*@par Attributes: - * Four attributes: - * @li filter_size: A Tensor of type integers. An integer vector representing the tensor shape of filter, + +*@par Required Attributes: + * @li filter_size: A tuple/list of type integers. An integer vector representing the tensor shape of filter, * where filter is a 5-D tensor [filter_depth, filter_height, filter_width, in_channels, out_channels] * or [out_channels, filter_depth, filter_height, filter_width, in_channels] or [out_channels, filter_depth, in_channel, filter_height, filter_width]. - * @li strides: A tuple/list of 3 integers. The stride of the sliding window for D/H/W dimension. + * @li strides: A tuple/list of 5 integers. Specifies the stride of the sliding window for each dimension of "x". + * The N and C dimensions must be 1. Has the same format as "x". * @li pads: A tuple/list of 6 integers, [front, back, top, bottom, left, right] pads on feature map + +*@par Attributes: + * Three attributes: * @li dilations: A tuple/list of 5 integers, The dilation factor for each dimension of input, now only support [1,1,1,1,1]. * @li groups: Number of blocked connections from input channels to output channels. - * @li data_format: An optional string from: "NDHWC", "NDCHW". Defaults to "NDHWC". Specify the data format of the input and output data. 
+ * @li data_format: An optional string from: "NDHWC", "NCDHW". Defaults to "NDHWC". Specify the data format of the input and output data. + *@par Outputs: * y: A Tensor. Has the same type as x *@par Third-party framework compatibility @@ -846,21 +887,26 @@ REG_OP(Conv3DBackpropFilterD) /** *@brief Computes the transpose of convolution 3d with respect to the input. *@par Inputs: - * Five inputs: + * Three inputs: * @li input_size: A Tensor of type int32. An integer vector representing the shape of input - * @li x: A Tensor. - * @li filter: A Tensor. Types is float16. + * @li x: A Tensor of type float16, currently does not support int8 + * @li filter: A Tensor of type float16. + +*@par Optional input: + * Two optional inputs * @li bias: An optional 1D tensor of the same type as "x". * @li offset_w: An optional 1D tensor for quantized deconvolution. Reserved. *@par Required Attributes: - * @li strides: A tuple/list of 3 integers. The stride of the sliding window for D/H/W dimension. + * @li strides: A tuple/list of 5 integers. Specifies the stride of the sliding window for each dimension of "x". + * The N and C dimensions must be 1. Has the same format as "x". * @li pads: A tuple/list of 6 integers + *@par Attributes: * Five attributes: * @li groups: Number of blocked connections from input channels to output channels. * @li dilations: A tuple/list of 5 integers, The dilation factor for each dimension of input, now only support [1,1,1,1,1] - * @li data_format: An optional string from: "NDHWC", "NCHWD". Defaults to "NDHWC". Specify the data format of the input and output data. + * @li data_format: An optional string from: "NDHWC", "NCDHW". Defaults to "NDHWC". Specify the data format of the input and output data. * @li output_padding: The size will be added in the output shape. 
* @li offset_x: Input offset_x value *@par Outputs: @@ -868,11 +914,11 @@ REG_OP(Conv3DBackpropFilterD) */ REG_OP(Conv3DTranspose) .INPUT(input_size, TensorType({DT_INT32, DT_INT64})) - .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) - .INPUT(filter, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .INPUT(x, TensorType({DT_FLOAT16})) + .INPUT(filter, TensorType({DT_FLOAT16})) .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16})) .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8})) - .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_FLOAT16})) .REQUIRED_ATTR(strides, ListInt) .REQUIRED_ATTR(pads, ListInt) .ATTR(dilations, ListInt, {1, 1, 1, 1, 1}) @@ -885,28 +931,29 @@ REG_OP(Conv3DTranspose) /** *@brief Computes the transpose of convolution 3d with respect to the input. *@par Inputs: - * Four inputs: - * @li x: A Tensor. - * @li filter: A Tensor. Types is float16. + * @li x: A Tensor of type float16. + * @li filter: A Tensor of type float16. + +*@par Optional inputs: * @li bias: An optional 1D tensor of the same type as "x". * @li offset_w: An optional 1D tensor for quantized deconvolution. Reserved. *@par Required Attributes: - * @li input_size: A Tensor of type int32. An integer vector representing the shape of input - * @li strides: A tuple/list of 3 integers. The stride of the sliding window for D/H/W dimension. - * @li pads: A tuple/list of 6 integers + * @li input_size: A tuple/list of type int32. An integer vector representing the shape of input + * @li strides: A tuple/list of 5 integers. Specifies the stride of the sliding window for each dimension of "x". + * The N and C dimensions must be 1. Has the same format as "x". + * @li pads: A tuple/list of 6 integers. + *@par Attributes: * Five attributes: * @li dilations: A tuple/list of 5 integers, The dilation factor for each dimension of input, now only support [1,1,1,1,1] * @li groups: Number of blocked connections from input channels to output channels. 
- * @li data_format: An optional string from: "NDHWC", "NCHWD". Defaults to "NDHWC". Specify the data format of the input and output data. + * @li data_format: An optional string from: "NDHWC", "NCDHW". Defaults to "NDHWC". Specify the data format of the input and output data. * @li output_padding: The size will be added in the output shape. * @li offset_x: Input offset_x value *@par Outputs: * y: A Tensor. Has the same type as filter */ - - REG_OP(Conv3DTransposeD) .INPUT(x, TensorType({DT_FLOAT16})) .INPUT(filter, TensorType({DT_FLOAT16})) @@ -923,5 +970,97 @@ REG_OP(Conv3DTransposeD) .ATTR(offset_x, Int, 0) .OP_END_FACTORY_REG(Conv3DTransposeD) +/** +*@brief Computes the transpose of convolution 2d with respect to the input. +*@par Inputs: + * Five inputs: + * @li input_size: A Tensor of type int32 or int64. An integer vector representing + * the shape of input. + * @li x: A Tensor of type float16, int8. + * @li filter: A Tensor of type float16, int8. Must have the same type as "x". + * @li bias: An optional 1D tensor of the same type as "x". + * @li offset_w: An optional 1D tensor for quantized inference. Reserved. +*@par Required Attributes: + * @li strides: A required list or tuple. The stride of the sliding window for + * height and width for H/W dimension. + * @li pads: A required list or tuple of int32. Padding added to each dimension + * of the input. +*@par Attributes: + * Five attributes: + * @li groups: Number of blocked connections from input channels to output channels. + * Defaults to "1". + * @li dilations: A tuple/list of 4 integers, The dilation factor for each dimension + * of input. Must be [1, 1, 1, 1]. + * @li data_format: An optional string from: "NHWC", "NCHW". Defaults to "NHWC". + * Specify the data format of the input and output data. + * @li output_padding: The size will be added in the output shape. Defaults + * to [0, 0, 0, 0]. + * @li offset_x: An optional int. Input offset, used for quantized inference. + * Defaults to "0". 
+*@par Outputs: + * y: A Tensor. Has the same type as "filter". +*/ +REG_OP(Conv2DTranspose) + .INPUT(input_size, TensorType({DT_INT32, DT_INT64})) + .INPUT(x, TensorType({DT_FLOAT16, DT_INT8})) + .INPUT(filter, TensorType({DT_FLOAT16, DT_INT8})) + .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_INT32})) + .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_INT32})) + .REQUIRED_ATTR(strides, ListInt) + .REQUIRED_ATTR(pads, ListInt) + .ATTR(dilations, ListInt, {1, 1, 1, 1}) + .ATTR(groups, Int, 1) + .ATTR(data_format, String, "NHWC") + .ATTR(output_padding, ListInt, {0, 0, 0, 0}) + .ATTR(offset_x, Int, 0) + .OP_END_FACTORY_REG(Conv2DTranspose) + +/** +*@brief Computes the transpose of convolution 2d with respect to the input. +*@par Inputs: + * Four inputs: + * @li x: A Tensor of type float16, int8. + * @li filter: A Tensor of type float16, int8. Must have the same type as "x". + * @li bias: An optional 1D tensor of the same type as "x". + * @li offset_w: An optional 1D tensor for quantized inference. Type is int8. Reserved. +*@par Required Attributes: + * @li input_size: A Tensor of type int32 or int64. An integer vector representing the + * shape of input. + * @li strides: A required list or tuple. The stride of the sliding window for + * height and width for H/W dimension. + * @li pads: A required list or tuple of int32. Padding added to each dimension + * of the input. +*@par Attributes: + * Five attributes: + * @li groups: Number of blocked connections from input channels to output channels. + * Defaults to "1". + * @li dilations: A tuple/list of 4 integers, The dilation factor for each dimension + * of input. Must be [1, 1, 1, 1]. + * @li data_format: An optional string from: "NHWC", "NCHW". Defaults to "NHWC". + * Specify the data format of the input and output data. + * @li output_padding: The size will be added in the output shape. Defaults + * to [0, 0, 0, 0]. + * @li offset_x: An optional int. 
Input offset, used for quantized inference. + * Defaults to "0". +*@par Outputs: + * y: A Tensor. Has the same type as "filter". +*/ +REG_OP(Conv2DTransposeD) + .INPUT(x, TensorType({DT_FLOAT16, DT_INT8})) + .INPUT(filter, TensorType({DT_FLOAT16, DT_INT8})) + .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_INT32})) + .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_INT32})) + .REQUIRED_ATTR(input_size, ListInt) + .REQUIRED_ATTR(strides, ListInt) + .REQUIRED_ATTR(pads, ListInt) + .ATTR(dilations, ListInt, {1, 1, 1, 1}) + .ATTR(groups, Int, 1) + .ATTR(data_format, String, "NHWC") + .ATTR(output_padding, ListInt, {0, 0, 0, 0}) + .ATTR(offset_x, Int, 0) + .OP_END_FACTORY_REG(Conv2DTransposeD) + } // namespace ge #endif // GE_OP_NN_CALCULATION_OPS_H diff --git a/third_party/fwkacllib/inc/ops/nn_detect_ops.h b/third_party/fwkacllib/inc/ops/nn_detect_ops.h index 0a91e237..9a17cd0d 100644 --- a/third_party/fwkacllib/inc/ops/nn_detect_ops.h +++ b/third_party/fwkacllib/inc/ops/nn_detect_ops.h @@ -909,20 +909,30 @@ REG_OP(DecodeBbox) /** *@brief Computes ClipBoxes function. -* + *@par Inputs: -*Inputs include: -* @li boxes_input: A Tensor. Must be float16. N-D with shape [N, 4]. -* @li img_size: A Tensor. Must be int32. shape [H, W]. -* +*@li boxes_input: A Tensor. Must be float16. N-D with shape [N, 4]. +*@li img_size: A Tensor. Must be int32. shape [H, W]. + *@par Outputs: -* @ boxes_output: A Tensor. Must have the same type as boxes_output. N-D with shape [N, 4]. +*boxes_output: A Tensor. Must have the same type as boxes_output. N-D with shape [N, 4]. */ REG_OP(ClipBoxes) .INPUT(boxes_input, TensorType({DT_FLOAT16})) .INPUT(img_size, TensorType({DT_INT32})) .OUTPUT(boxes_output, TensorType({DT_FLOAT16})) .OP_END_FACTORY_REG(ClipBoxes) + +/** +*@brief Computes ClipBoxesD function. + +*@par Inputs: +*@li boxes_input: A Tensor. Must be float16. N-D with shape [N, 4]. +*@li img_size: A Tensor. Must be int32. shape [H, W]. 
+ +*@par Outputs: +*boxes_output: A Tensor. Must have the same type as boxes_output. N-D with shape [N, 4]. +*/ REG_OP(ClipBoxesD) .INPUT(boxes_input, TensorType({DT_FLOAT16})) .REQUIRED_ATTR(img_size, ListInt) @@ -959,13 +969,13 @@ REG_OP(FastrcnnPredictions) /** *@brief Computes Fastrcnn RpnProposals function. -* + *@par Inputs: *Inputs include: * @li rois: A Tensor. Must be float16. N-D with shape [N, 4]. * @li cls_bg_prob: A Tensor. Must be float16. N-D with shape [N, 1]. * @li img_size: A Tensor. Must be int32. shape [H, W]. -* + *@par Attributes: * @li score_threshold: required, float, threahold of topk process. * @li k: required, Int, threahold of topk process. @@ -975,10 +985,14 @@ REG_OP(FastrcnnPredictions) * @li score_filter: bool, mark of score_filter. Defaults to "true" * @li box_filter: bool, mark of box_filter. Defaults to "true" * @li score_sigmoid: bool, mark of score_sigmoid. Defaults to "false" + *@par Outputs: * @li sorted_rois: A Tensor. Must be float16. N-D with shape [N, 4]. * @li sorted_scores: A Tensor. Must be float16. N-D with shape [N, 1]. * @li sorted_classes: A Tensor. Must be float16. N-D with shape [N, 1]. + +* @par Third-party framework compatibility +* Compatible with the TensorFlow operator Unpack. */ REG_OP(RpnProposals) .INPUT(rois, TensorType({DT_FLOAT16})) @@ -994,6 +1008,31 @@ REG_OP(RpnProposals) .ATTR(score_sigmoid, Bool, false) .OUTPUT(sorted_box, TensorType({DT_FLOAT16})) .OP_END_FACTORY_REG(RpnProposals) + +/** +*@brief Computes Fastrcnn RpnProposalsD function. + +*@par Inputs: +*@li rois: A Tensor. Must be float16. N-D with shape [N, 4]. +*@li cls_bg_prob: A Tensor. Must be float16. N-D with shape [N, 1]. + +*@par Attributes: +*@li img_size: A Tensor size of image. Must be int32. shape [H, W]. +*@li score_threshold: required, float, threahold of topk process. +*@li k: required, Int, threahold of topk process. +*@li min_size: required, float, threahold of nms process. 
+*@li nms_threshold: required, float, threshold of nms process. +*@li post_nms_num: required, float, threshold of nms process. +*@li score_filter: bool, mark of score_filter. Defaults to "true" +*@li box_filter: bool, mark of box_filter. Defaults to "true" +*@li score_sigmoid: bool, mark of score_sigmoid. Defaults to "false" + +*@par Outputs: +*sorted_box: A Tensor of output. Must be float16. N-D with shape [N, 1]. + +* @par Third-party framework compatibility +* Compatible with the pytorch operator RPNProposals. +*/ REG_OP(RpnProposalsD) .INPUT(rois, TensorType({DT_FLOAT16})) .INPUT(cls_bg_prob, TensorType({DT_FLOAT16})) @@ -1165,17 +1204,17 @@ REG_OP(DecodeWheelsTarget) *@li iou_threshold: A required attribute of type float32, specifying the nms iou iou_threshold. *@li max_size_per_class: A required attribute of type int, specifying the nms output num per class. *@li max_total_size: A required attribute of type int, specifying the the nms output num per batch. -*@li change_coordinate_frame: A required attribute of type bool, whether to normalize coordinates after clipping. -*@li transpose_box: A required attribute of type bool, whether inserted transpose before this op. +*@li change_coordinate_frame: An optional attribute of type bool, whether to normalize coordinates after clipping. +*@li transpose_box: An optional attribute of type bool, whether inserted transpose before this op. must be "false". *@par Outputs: *@li nmsed_boxes: A 3D Tensor of type float16 with shape (batch, max_total_size, 4), * specifying the output nms boxes per batch. -*@li nmsed_scores: A 2D Tensor of type float16 with shape (N, 4), +*@li nmsed_scores: A 2D Tensor of type float16 with shape (batch, max_total_size), * specifying the output nms score per batch. -*@li nmsed_classes: A 2D Tensor of type float16 with shape (N, 4), +*@li nmsed_classes: A 2D Tensor of type float16 with shape (batch, max_total_size), * specifying the output nms class per batch.
-*@li nmsed_num: A 1D Tensor of type float16 with shape (N, 4), specifying the valid num of nmsed_boxes. +*@li nmsed_num: A 1D Tensor of type int32 with shape (batch), specifying the valid num of nmsed_boxes. *@attention Constraints: * Only computation of float16 data is supported. @@ -1191,8 +1230,8 @@ REG_OP(BatchMultiClassNonMaxSuppression) .OUTPUT(nmsed_num, TensorType({DT_INT32})) .REQUIRED_ATTR(score_threshold, Float) .REQUIRED_ATTR(iou_threshold, Float) - .REQUIRED_ATTR(max_size_per_class, Float) - .REQUIRED_ATTR(max_total_size, Float) + .REQUIRED_ATTR(max_size_per_class, Int) + .REQUIRED_ATTR(max_total_size, Int) .ATTR(change_coordinate_frame, Bool, false) .ATTR(transpose_box, Bool, false) .OP_END_FACTORY_REG(BatchMultiClassNonMaxSuppression) diff --git a/third_party/fwkacllib/inc/ops/nn_norm_ops.h b/third_party/fwkacllib/inc/ops/nn_norm_ops.h index f5b20cdd..52e7702c 100644 --- a/third_party/fwkacllib/inc/ops/nn_norm_ops.h +++ b/third_party/fwkacllib/inc/ops/nn_norm_ops.h @@ -338,6 +338,21 @@ REG_OP(ConfusionSoftmaxGrad) .OUTPUT(y, TensorType({DT_FLOAT16,DT_FLOAT})) .OP_END_FACTORY_REG(ConfusionSoftmaxGrad) +/** +*@brief Function softmax gradients ext. + +*@par Inputs: +* @li grad: A Tensor dtype of float16. +* @li x1: A Tensor dtype of float16, float32. +* @li x2: A Tensor dtype of float16. + +*@par Attributes: +*@li axis: A int Scalar. The axis for reduce. +*@li keepdims: A bool Scalar. If true, retains reduced dimensions with length 1. + +*@par Outputs: +*y: A Tensor dtype of float16, float32. +*/ REG_OP(SoftmaxGradExt) .INPUT(grad, TensorType({DT_FLOAT16,DT_FLOAT})) .INPUT(x1, TensorType({DT_FLOAT16,DT_FLOAT})) @@ -346,7 +361,7 @@ REG_OP(SoftmaxGradExt) .ATTR(axes, Int, 1) .ATTR(keep_dims, Bool, false) .OP_END_FACTORY_REG(SoftmaxGradExt) - + /** *@brief Normalizes the input. 
@@ -860,6 +875,23 @@ REG_OP(InstanceNormV2) .ATTR(epsilon, Float, 0.00001) .OP_END_FACTORY_REG(InstanceNormV2) +REG_OP(INInferV2D) + .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT})) + .OPTIONAL_INPUT(gamma, TensorType({DT_FLOAT})) + .OPTIONAL_INPUT(beta, TensorType({DT_FLOAT})) + .OPTIONAL_INPUT(mean, TensorType({DT_FLOAT})) + .OPTIONAL_INPUT(variance, TensorType({DT_FLOAT})) + .OPTIONAL_INPUT(variance_sqrt, TensorType({DT_FLOAT})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT})) + .OUTPUT(batch_mean, TensorType({DT_FLOAT})) + .OUTPUT(batch_variance, TensorType({DT_FLOAT})) + .OP_END_FACTORY_REG(INInferV2D) + +REG_OP(InHost) + .INPUT(variance, TensorType({DT_FLOAT})) + .OUTPUT(variance_sqrt, TensorType({DT_FLOAT})) + .ATTR(epsilon, Float, 0.00001) + .OP_END_FACTORY_REG(InHost) } // namespace ge #endif //GE_OP_NN_NORM_OPS_H diff --git a/third_party/fwkacllib/inc/ops/nn_pooling_ops.h b/third_party/fwkacllib/inc/ops/nn_pooling_ops.h index 98c4b246..a7d4c6e3 100644 --- a/third_party/fwkacllib/inc/ops/nn_pooling_ops.h +++ b/third_party/fwkacllib/inc/ops/nn_pooling_ops.h @@ -219,6 +219,39 @@ REG_OP(MaxPool) .ATTR(data_format, String, "NHWC") .OP_END_FACTORY_REG(MaxPool) +/** +*@brief Performs max 3d pooling on the input. + +*@par Inputs: +*x: An NC1HWC0 Tensor. Supported type:float16, float32, double, int8, int16, \n +int32, int64, uint8, uint16, qint8 + +*@par Attributes: +*@li ksize: A required list of int8, int16, int32, or int64 values, \n +specifying the size of the window for each dimension of the input tensor. \n +No default value. +*@li strides: A required list of int8, int16, int32, or int64 values, \n +specifying the stride of the sliding window for each dimension of \n +the input tensor. No default value. +*@li padding: A required string. No default value. +*@li pads: A list type of int32. Default value {0, 0, 0, 0, 0, 0}. +*@li dilation: A list type of int32. Default value {0,0,0}. +*@li ceil_mode: A ceil mode number of int32 . Default value 0. 
+*@li data_format: An optional string. Defaults to "NHWC". + +*@par Outputs: +*y: A Tensor. Has the same type and format as input "x". + +*@attention Constraints: +*@li "ksize" is a list that has length 4: ksize[0] = 1 or ksize[3] = 1, + * ksize[1] * ksize[2] <= 255. +*@li "stride is a list that has length 4: strides[0] = 1 or strides[3] = 1, + * strides[1] <= 63, strides[0] >= 1, strides[2] <= 63, strides[2] >= 1. +*@li "padding" is either "SAME" or "VALID". + +*@par Third-party framework compatibility +* Compatible with the TensorFlow operator MaxPool3D. +*/ REG_OP(MaxPool3D) .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT32, DT_DOUBLE})) .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT32, DT_DOUBLE})) diff --git a/third_party/fwkacllib/inc/ops/nn_training_ops.h b/third_party/fwkacllib/inc/ops/nn_training_ops.h index cc17103c..0ecaf9a3 100644 --- a/third_party/fwkacllib/inc/ops/nn_training_ops.h +++ b/third_party/fwkacllib/inc/ops/nn_training_ops.h @@ -1480,20 +1480,21 @@ REG_OP(ApplyProximalAdagradD) *@par Inputs: * Seven inputs, including:\n -* @li var: A mutable Tensor. +* @li var: A mutable Tensor.\n * TensorType::NumberType(). Should be a Variable Tensor. -* @li accum: A mutable Tensor of the same type as "var". -* Should be a Variable Tensor. -* @li lr: A Tensor of the same type as "var". -* Scaling factor. Must be a scalar. -* @li l1: A Tensor of the same type as "var". -* L1 regulariation. Must be a scalar. -* @li l2: A Tensor of the same type as "var". -* L2 regulariation. Must be a scalar. -* @li grad: A Tensor. Has the same type as "var". +* @li accum: A mutable Tensor of the same type as "var".\n +* Should be a Variable Tensor. Should be greater than or equal to zero.\n +* Accum and grad cannot be equal to zero at the same time. +* @li lr: A Tensor of the same type as "var".\n +* Scaling factor. Must be a scalar. Should be greater than zero. +* @li l1: A Tensor of the same type as "var".\n +* L1 regulariation. Must be a scalar. 
Should be greater than or equal to zero. +* @li l2: A Tensor of the same type as "var".\n +* L2 regulariation. Must be a scalar. Should be greater than or equal to zero. +* @li grad: A Tensor. Has the same type as "var".\n * The gradient. -* @li indices: A vector of indices into the first dimension of "var" and "accum". -* TensorType::IndexNumberType(). +* @li indices: A vector of indices into the first dimension of "var" and "accum".\n +* TensorType::IndexNumberType(). Can contain duplicate values. *@par Attributes: *use_locking: An optional bool. Defaults to "False".\n @@ -1528,17 +1529,18 @@ REG_OP(SparseApplyProximalAdagrad) * @li var: A mutable Tensor.\n * TensorType::NumberType(). Should be a Variable Tensor. * @li accum: A mutable Tensor of the same type as "var".\n -* Should be a Variable Tensor. +* Should be a Variable Tensor. Should be greater than or equal to zero.\n +* Accum and grad cannot be equal to zero at the same time. * @li lr: A Tensor of the same type as "var".\n -* Scaling factor. Must be a scalar. +* Scaling factor. Must be a scalar. Should be greater than zero. * @li l1: A Tensor of the same type as "var".\n -* L1 regulariation. Must be a scalar. +* L1 regulariation. Must be a scalar. Should be greater than or equal to zero. * @li l2: A Tensor of the same type as "var".\n -* L2 regulariation. Must be a scalar. +* L2 regulariation. Must be a scalar. Should be greater than or equal to zero. * @li grad: A Tensor. Has the same type as "var". \n * The gradient. * @li indices: A vector of indices into the first dimension of "var" and "accum".\n -* TensorType::IndexNumberType(). +* TensorType::IndexNumberType(). Can contain duplicate values. *@par Attributes: *use_locking: An optional bool. Defaults to "False".\n @@ -2113,11 +2115,12 @@ REG_OP(LarsV2Update) * @li var: A mutable Tensor. Must be of type TensorType::NumberType(). * Should be a Variable Tensor. * @li accum: A mutable Tensor of the same type as "var". -* Should be a Variable Tensor. 
+* Should be a Variable Tensor. The value of accum must be greater than 0. * @li linear: A mutable Tensor of the same type as "var". * Should be a Variable Tensor. * @li grad: A Tensor of the same type as "var", for the gradient. * @li indices: A vector of indices into the first dimension of var and accum. +* The value of indices must be unique. Otherwise, the result is unpredictable. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar. * @li l1: A Tensor of the same type as "var", for L1 regulariation. Must be a scalar. * @li l2: A Tensor of the same type as "var", for L2 regulariation. Must be a scalar. @@ -2157,11 +2160,12 @@ REG_OP(SparseApplyFtrl) * @li var: A mutable Tensor. Must be of type TensorType::NumberType(). * Should be a Variable Tensor. * @li accum: A mutable Tensor of the same type as "var". -* Should be a Variable Tensor. +* Should be a Variable Tensor. The value of accum must be greater than 0. * @li linear: A mutable Tensor of the same type as "var". * Should be a Variable Tensor. * @li grad: A Tensor of the same type as "var", for the gradient. * @li indices: A vector of indices into the first dimension of var and accum. +* The value of indices must be unique. Otherwise, the result is unpredictable. * @par Attributes: * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar. diff --git a/third_party/fwkacllib/inc/ops/nonlinear_fuc_ops.h b/third_party/fwkacllib/inc/ops/nonlinear_fuc_ops.h index a01073cf..310325c8 100644 --- a/third_party/fwkacllib/inc/ops/nonlinear_fuc_ops.h +++ b/third_party/fwkacllib/inc/ops/nonlinear_fuc_ops.h @@ -531,6 +531,19 @@ REG_OP(LeakyReluGrad) .OUTPUT(backprops, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) .OP_END_FACTORY_REG(LeakyReluGrad) +/** +*@brief Thresholds grad each element of the input Tensor. + +*@par Inputs: +* @li gradients: A Tensor shape and dtype of input gradients. Support float16, float32, int8, uint8, int32. 
+* @li features: A Tensor shape and dtype of input features. Support float16, float32, int8, uint8, int32. + +*@par Attributes: +*threshold: A float32 scale value to threshold at. + +*@par Outputs: +*backprops: A Tensor of shape and dtype of output backprops, should be same shape and type as inputs. +*/ REG_OP(ThresholdGradV2D) .INPUT(gradients, TensorType({DT_INT32, DT_FLOAT16})) .INPUT(features, TensorType({DT_INT32, DT_FLOAT16})) @@ -538,6 +551,19 @@ REG_OP(ThresholdGradV2D) .REQUIRED_ATTR(threshold, Float) .OP_END_FACTORY_REG(ThresholdGradV2D) +/** +*@brief Thresholds each element of the input Tensor y = (x > threshold) ? x : value. + +*@par Inputs: +*x: A Tensor dtype of float16, float32, int8, uint8, int32. + +*@par Attributes: +*@li threshold: A float32 scale value to threshold at. +*@li value: A float32 scale value to replace with. + +*@par Outputs: +*y: A Tensor of shape and dtype of output, should be same shape and type as input. +*/ REG_OP(ThresholdV2D) .INPUT(x, TensorType::RealNumberType()) .OUTPUT(y, TensorType::RealNumberType()) @@ -545,6 +571,25 @@ REG_OP(ThresholdV2D) .REQUIRED_ATTR(value, Float) .OP_END_FACTORY_REG(ThresholdV2D) +/** +*@brief: Computes hyperbolic tangent of "x" element-wise. + +*@par Inputs: +*One input: +*x: A Tensor. Must be one of the following types: float16, float32. + +*@par Outputs: +*y: A Tensor. Has the same type as "x". + +*@par Third-party framework compatibility +* Compatible with TensorFlow operator Mish. 
+*/ + +REG_OP(Mish) + .INPUT(x, TensorType({ DT_FLOAT,DT_FLOAT16 })) + .OUTPUT(y, TensorType({ DT_FLOAT,DT_FLOAT16 })) + .OP_END_FACTORY_REG(Mish) + } // namespace ge #endif // GE_OP_NONLINEAR_FUC_OPS_H diff --git a/third_party/fwkacllib/inc/ops/npu_loss_scale_ops.h b/third_party/fwkacllib/inc/ops/npu_loss_scale_ops.h index 1c702738..8e9e1638 100644 --- a/third_party/fwkacllib/inc/ops/npu_loss_scale_ops.h +++ b/third_party/fwkacllib/inc/ops/npu_loss_scale_ops.h @@ -19,15 +19,40 @@ #include "graph/operator_reg.h" namespace ge { + +/** +*@brief Computes NPU alloc float status operator function. + +*@par Outputs: +*data: A Tensor of data value. Must be float32. +*/ REG_OP(NPUAllocFloatStatusOperator) .OUTPUT(data, TensorType({DT_FLOAT})) .OP_END_FACTORY_REG(NPUAllocFloatStatusOperator) +/** +*@brief Computes NPU clear float status operator function. + +*@par Inputs: +*addr: A Tensor of data memory address. Must be float32. + +*@par Outputs: +*data: A Tensor of data value. Must be float32. +*/ REG_OP(NPUClearFloatStatusOperator) .INPUT(addr, TensorType{DT_FLOAT}) .OUTPUT(data, TensorType({DT_FLOAT})) .OP_END_FACTORY_REG(NPUClearFloatStatusOperator) +/** +*@brief Computes NPU get float status operator function. + +*@par Inputs: +*addr: A Tensor of data memory address. Must be float32. + +*@par Outputs: +*data: A Tensor of data value. Must be float32. +*/ REG_OP(NPUGetFloatStatusOperator) .INPUT(addr, TensorType{DT_FLOAT}) .OUTPUT(data, TensorType({DT_FLOAT})) @@ -47,7 +72,7 @@ REG_OP(NPUAllocFloatStatus) *@brief Set the value of address 0x40000 to 0 in each core. *@par Inputs: -*@li addr: A tensor of type float32. +*addr: A tensor of type float32. *@par Outputs: *data: A Tensor of type float32. @@ -61,7 +86,7 @@ REG_OP(NPUClearFloatStatus) *@brief Get the value of address 0x40000. *@par Inputs: -*@li addr: A tensor of type float32. +*addr: A tensor of type float32. *@par Outputs: *data: A Tensor of type float32. 
diff --git a/third_party/fwkacllib/inc/ops/pad_ops.h b/third_party/fwkacllib/inc/ops/pad_ops.h index 7e9c65f4..f7153936 100644 --- a/third_party/fwkacllib/inc/ops/pad_ops.h +++ b/third_party/fwkacllib/inc/ops/pad_ops.h @@ -240,6 +240,28 @@ REG_OP(AscendPadding) .ATTR(pad_dim_size, Int, 8) .OP_END_FACTORY_REG(AscendPadding) -} // namespace ge +/** +*@brief EmbeddingRankId, traverse the index calculation server and its position in the server. + +*@par Inputs: +*One input, include: +*addr_table: Tensor which last dimension must be 3. For example: [8, 3]. +*index: Tensor For example: [640000]. +*@par Outputs: +*rank_id: Tensor the first dimension of index to Size, [size, 3]. + Tensor which last dimension must be 3.For example: [640000, 3] +*@par Third-party framework compatibility +* Compatible with the TensorFlow operator Diag. +*/ +REG_OP(EmbeddingRankId) + .INPUT(addr_table, TensorType({DT_UINT64})) + .INPUT(index, TensorType({DT_UINT32})) + .OUTPUT(rank_id, TensorType({DT_UINT64})) + .ATTR(row_memory, Int, 320) + .ATTR(mode, String, "mod") + .OP_END_FACTORY_REG(EmbeddingRankId) + + +} // namespace ge #endif //GE_OP_PAD_OPS_H diff --git a/third_party/fwkacllib/inc/ops/quantize_ops.h b/third_party/fwkacllib/inc/ops/quantize_ops.h index 4bf0e5bf..4cb80cea 100644 --- a/third_party/fwkacllib/inc/ops/quantize_ops.h +++ b/third_party/fwkacllib/inc/ops/quantize_ops.h @@ -71,6 +71,7 @@ REG_OP(Dequantize) *@par Outputs: *y: The quantized output tensor of type int8 and with format NC1HWC0. + *@par Third-party framework compatibility * It is a custom operator. It has no corresponding operator in Caffe. */ @@ -97,6 +98,7 @@ REG_OP(AscendQuant) *@par Outputs: *y: The dequantized output tensor of type float16 or float32 and with format NC1HWC0. + *@par Third-party framework compatibility * It is a custom operator. It has no corresponding operator in Caffe. 
*/ @@ -109,6 +111,24 @@ REG_OP(AscendDequant) .ATTR(dtype, Int, DT_FLOAT) .OP_END_FACTORY_REG(AscendDequant) +/** +*@brief Anti quantizes the input. + +*@par Inputs: +*x: An NC1HWC0 tensor of type int8, specifying the input. + +*@par Attributes: +*@li scale: A required float32 scale. +*@li offset: A required float32 offset. +*@li dtype: A optional int32, specifying the output data type. Defaults to "DT_FLOAT". +*@li sqrt_mode: A optional bool, specifying whether to perform square root on "scale", either "True" or "False". Defaults to "False". + +*@par Outputs: +*y: The dequantized output tensor of type float16 or float32 and with format NC1HWC0. + +*@par Third-party framework compatibility +* It is a custom operator. It has no corresponding operator in Caffe. +*/ REG_OP(AscendAntiQuant) .INPUT(x, TensorType({DT_INT8})) .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT})) @@ -118,6 +138,23 @@ REG_OP(AscendAntiQuant) .ATTR(sqrt_mode, Bool, false) .OP_END_FACTORY_REG(AscendAntiQuant) +/** +*@brief Dequantizes the input of int16. + +*@par Inputs: +*@li x0: An NC1HWC0 tensor of type int32, specifying the input. +*@li deq_scale: An NC1HWC0 tensor of type float16 or uint64, specifying the scaling ratio. +*@li x1: An NC1HWC0 tensor of type int16, specifying the input. + +*@par Attributes: +*relu_flag: A optional bool, specifying whether to perform ReLU, either "True" or "False". Defaults to "False". + +*@par Outputs: +*y: The dequantized output tensor of type float16 or float32 and with format NC1HWC0. + +*@par Third-party framework compatibility +* It is a custom operator. It has no corresponding operator in Caffe. +*/ REG_OP(AscendDequantS16) .INPUT(x0, TensorType({DT_INT32})) .INPUT(deq_scale, TensorType({DT_UINT64})) @@ -126,6 +163,22 @@ REG_OP(AscendDequantS16) .ATTR(relu_flag, Bool, false) .OP_END_FACTORY_REG(AscendDequantS16) +/** +*@brief Requantizes the input. + +*@par Inputs: +*@li x: An NC1HWC0 tensor of type int32, specifying the input. 
+*@li req_scale: An NC1HWC0 tensor of type uint64, specifying the scaling ratio. + +*@par Attributes: +*relu_flag: An optional bool, specifying whether to perform ReLU, either "True" or "False". Defaults to "False". + +*@par Outputs: +*y: The dequantized output tensor of type int8 and with format NC1HWC0. + +*@par Third-party framework compatibility +* It is a custom operator. It has no corresponding operator in Caffe. +*/ REG_OP(AscendRequant) .INPUT(x, TensorType({DT_INT32})) .INPUT(req_scale, TensorType({DT_UINT64})) @@ -133,6 +186,25 @@ REG_OP(AscendRequant) .ATTR(relu_flag, Bool, false) .OP_END_FACTORY_REG(AscendRequant) +/** +*@brief Requantizes the input of int16. + +*@par Inputs: +*@li x: An NC1HWC0 tensor of type int16, specifying the input. +*@li req_scale: An NC1HWC0 tensor of type uint64, specifying the scaling ratio. +*@li x1: An NC1HWC0 tensor of type int16. + +*@par Attributes: +*@li dual_output: An optional bool, specifying whether to perform dual output, either "True" or "False". Defaults to "False". +*@li relu_flag: An optional bool, specifying whether to perform ReLU, either "True" or "False". Defaults to "False". + +*@par Outputs: +*@li y: The dequantized output tensor of type int8 and with format NC1HWC0. +*@li y1: The dequantized output tensor of type int16 and with format NC1HWC0. + +*@par Third-party framework compatibility +* It is a custom operator. It has no corresponding operator in Caffe. +*/ REG_OP(AscendRequantS16) .INPUT(x, TensorType({DT_INT16})) .INPUT(req_scale, TensorType({DT_UINT64})) diff --git a/third_party/fwkacllib/inc/ops/reduce_ops.h b/third_party/fwkacllib/inc/ops/reduce_ops.h index 8cf9f342..d3dfefe1 100644 --- a/third_party/fwkacllib/inc/ops/reduce_ops.h +++ b/third_party/fwkacllib/inc/ops/reduce_ops.h @@ -173,8 +173,8 @@ REG_OP(BNInfer) *@brief Performs reduced batch normalization. For some scene which don't contain assignmoving average.
-*@par Inputs:\n -* Five inputs, including: (NC1HWC0 supported) +*@par Inputs: +*Five inputs, including: (NC1HWC0 supported) *@li x: A 5D Tensor of type float16 or float32. *@li sum: A 5D Tensor of type float32 for the output of operator BNTrainingReduce. *@li square_sum: A 5D Tensor of type float32 for the output of operator BNTrainingReduce. @@ -184,15 +184,15 @@ assignmoving average. *@par Attributes: *epsilon: A required float32, specifying the small value added to variance to avoid dividing by zero. -*@par Outputs:\n -* Three outputs, including: (NC1HWC0 supported) +*@par Outputs: +*Three outputs, including: (NC1HWC0 supported) *@li y: A 5D Tensor of type float16 or float32, for normalized "x". *@li batch_mean: A 5D Tensor of type float32, for the mean of "x". *@li batch_variance: A 5D Tensor of type float32, for the variance of "x". *@attention Constraints: -*@li This operator is used in conjunction with BNTrainingReduce. -*@li For Ascend 310, the result accuracy fails to reach 1‰ due to the square root instruction. +*This operator is used in conjunction with BNTrainingReduce. \n +For Ascend 310, the result accuracy fails to reach 1‰ due to the square root instruction. */ REG_OP(BNTrainingUpdateV2) .INPUT(x, TensorType({DT_FLOAT16,DT_FLOAT})) @@ -206,6 +206,33 @@ REG_OP(BNTrainingUpdateV2) .OUTPUT(batch_variance, TensorType({DT_FLOAT})) .OP_END_FACTORY_REG(BNTrainingUpdateV2) +/** +*@brief Performs reduced batch normalization v3. For some scene which don't contain +assignmoving average. + +*@par Inputs: +* Five inputs, including: (NC1HWC0 supported) +*@li x: A 5D Tensor of type float16 or float32. +*@li sum: A 5D Tensor of type float32 for the output of operator BNTrainingReduce. +*@li square_sum: A 5D Tensor of type float32 for the output of operator BNTrainingReduce. +*@li scale: A 5D Tensor of type float32, for the scaling factor. +*@li offset: A 5D Tensor of type float32, for the scaling offset. 
+ +*@par Attributes: +*epsilon: A required float32, specifying the small value added to variance to avoid dividing by zero. + +*@par Outputs: +* Three outputs, including: (NC1HWC0 supported) +*@li y: A 5D Tensor of type float16 or float32, for normalized "x". +*@li batch_mean: A 5D Tensor of type float32, for the mean of "x". +*@li batch_variance: A 5D Tensor of type float32, for the variance of "x". +*@li reserve_1: A 5D Tensor of type float32, for the mean of batch "x". Has the same type as batch_mean. +*@li reserve_2: A 5D Tensor of type float32, for the variance of batch "x". Has the same type as batch_mean. + +*@attention Constraints: +*@li This operator is used in conjunction with BNTrainingReduce. +*@li For Ascend 310, the result accuracy fails to reach 1‰ due to the square root instruction. +*/ REG_OP(BNTrainingUpdateV3) .INPUT(x, TensorType({DT_FLOAT16,DT_FLOAT})) .INPUT(sum, TensorType({DT_FLOAT})) diff --git a/third_party/fwkacllib/inc/ops/rnn.h b/third_party/fwkacllib/inc/ops/rnn.h index b72d9a79..ebc59a34 100644 --- a/third_party/fwkacllib/inc/ops/rnn.h +++ b/third_party/fwkacllib/inc/ops/rnn.h @@ -67,6 +67,17 @@ REG_OP(BasicLSTMCell) .ATTR(activation, String, "tanh") .OP_END_FACTORY_REG(BasicLSTMCell) +/** +*@brief: Dynamic LSTM forward calculation. + +*@par Inputs: +*@li x:A 4D Tensor. Must be the type float32. The format must be FRACTAL_NZ. +*@li w:A 4D Tensor. Must be the type float32. The format must be FRACTAL_Z. +*@li b:A 1D Tensor. Must be the type float32. The format must be ND. + +*@par Outputs: +*output_h:A Tensor of output. Must be the type float32. The format must be FRACTAL_Z. 
+*/ REG_OP(DynamicLSTM) .INPUT(x, TensorType({DT_FLOAT32})) .INPUT(w, TensorType({DT_FLOAT32})) diff --git a/third_party/fwkacllib/inc/ops/selection_ops.h b/third_party/fwkacllib/inc/ops/selection_ops.h index c2e6f13a..47cf4a47 100644 --- a/third_party/fwkacllib/inc/ops/selection_ops.h +++ b/third_party/fwkacllib/inc/ops/selection_ops.h @@ -449,7 +449,7 @@ REG_OP(StridedSliceGrad) *@par Inputs: *Three inputs, including: * @li x: A Tensor of type NumberType. -* @li segment_ids: A 1D Tensor of type IndexNumberType, whose shape is a prefix +* @li segment_ids: A Tensor of type IndexNumberType, whose shape is a prefix * of "x.shape". * @li num_segments: A Tensor of type IndexNumberType. @@ -472,7 +472,7 @@ REG_OP(UnsortedSegmentSum) *@par Inputs: *Two inputs, including: * @li x: A Tensor of type float16, float32, int32, int8, uint8. -* @li segment_ids: A 1D Tensor of type int32, whose shape is a prefix +* @li segment_ids: A Tensor of type int32, whose shape is a prefix * of "x.shape". *@par Attributes: @@ -794,7 +794,7 @@ REG_OP(SliceD) * @li k =< 5120 * @li Size of the last dimension =< 65500 * @li sorted = true -* @li Don't support to get score on the platform of Ascend310 +* @li It's unstable sorted indices on the platform of Ascend310 * @par Third-party framework compatibility * @li Compatible with the TensorFlow operator TopK. @@ -1417,6 +1417,54 @@ REG_OP(UnsortedSegmentMinD) .OP_END_FACTORY_REG(UnsortedSegmentMinD) /** +* @brief Computes the maximum along segments of a tensor. + +* @par Inputs: +* Three inputs, including: +* @li x: A Tensor of type RealNumberType. +* @li segment_ids: A 1D Tensor of type IndexNumberType, whose shape is a prefix +* of "x.shape". +* @li num_segments: A Tensor of type IndexNumberType. + +* @par Outputs: +* y: A Tensor of type RealNumberType. + +* @see UnsortedSegmentSum(), UnsortedSegmentProd(), + +* @par Third-party framework compatibility +* @li Compatible with the TensorFlow operator UnsortedSegmentMax. 
+*/ +REG_OP(UnsortedSegmentMax) + .INPUT(x, TensorType::RealNumberType()) + .INPUT(segment_ids, TensorType::IndexNumberType()) + .INPUT(num_segments, TensorType::IndexNumberType()) + .OUTPUT(y, TensorType::RealNumberType()) + .OP_END_FACTORY_REG(UnsortedSegmentMax) + +/** +* @brief Computes the maximum along segments of a tensor. + +* @par Inputs: +* Two inputs, including: +* @li x: A Tensor of the following types: int32, int16, float16, float32. +* @li segment_ids: A 1D Tensor of type int32, whose shape is a prefix +* of "x.shape". + +* @par Attributes: +* num_segments: A required int32, specifying the number of distinct segment IDs. + +* @par Outputs: +* y: A Tensor. Must have the same type as input "x". + +* @see UnsortedSegmentProdD(), +*/ +REG_OP(UnsortedSegmentMaxD) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT16})) + .INPUT(segment_ids, TensorType({DT_INT32})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT16})) + .REQUIRED_ATTR(num_segments, Int) + .OP_END_FACTORY_REG(UnsortedSegmentMaxD) +/** * @brief Computes the product along segments of a tensor. * @par Inputs: diff --git a/third_party/fwkacllib/inc/ops/transformation_ops.h b/third_party/fwkacllib/inc/ops/transformation_ops.h index 5bbf1e78..ddbb1b4d 100644 --- a/third_party/fwkacllib/inc/ops/transformation_ops.h +++ b/third_party/fwkacllib/inc/ops/transformation_ops.h @@ -123,6 +123,22 @@ REG_OP(Transpose) .OUTPUT(y, TensorType::BasicType()) .OP_END_FACTORY_REG(Transpose) +/** +*@brief Doing format_transfer for various data format only \n +support NHWC/NCHW to NC1HWC0 and NC1HWC0 to NHWC/NCHW \n +NCHW to FRACTAL_Zn or FRACTAL_Zn to NCHW \n +HWCN to FRACTAL_Zn or FRACTAL_Zn to HWCN. + +*@par Inputs: +*src: A Tensor dtype of all types. + +*@par Attributes: +*@li src_format: A string source data format, can be NHWC, NCHW, FRACTAL_Zn etc. +*@li dst_format: A string target data format, can be NC1HWC0, NCHW, FRACTAL_Zn etc. 
+ +*@par Outputs: +*dst: A Tensor dtype of all types. +*/ REG_OP(TransData) .INPUT(src, TensorType::BasicType()) .OUTPUT(dst, TensorType::BasicType()) @@ -505,6 +521,7 @@ REG_OP(Unpack) * @attention Constraints: * "ksizes", "strides" and "rates" are lists of integers. + * @par Third-party framework compatibility * Compatible with the TensorFlow operator ExtractImagePatches. */ diff --git a/third_party/fwkacllib/inc/runtime/base.h b/third_party/fwkacllib/inc/runtime/base.h index 2d6503f9..572053f6 100644 --- a/third_party/fwkacllib/inc/runtime/base.h +++ b/third_party/fwkacllib/inc/runtime/base.h @@ -37,47 +37,298 @@ extern "C" { * @brief runtime error numbers. */ typedef enum tagRtError { - RT_ERROR_NONE = 0x0, // success - RT_ERROR_INVALID_VALUE = 0x1, // invalid value - RT_ERROR_MEMORY_ALLOCATION = 0x2, // memory allocation fail - RT_ERROR_INVALID_RESOURCE_HANDLE = 0x3, // invalid handle - RT_ERROR_INVALID_DEVICE_POINTER = 0x4, // invalid device point - RT_ERROR_INVALID_MEMCPY_DIRECTION = 0x5, // invalid memory copy dirction - RT_ERROR_INVALID_DEVICE = 0x6, // invalid device - RT_ERROR_NO_DEVICE = 0x7, // no valid device - RT_ERROR_CMD_OCCUPY_FAILURE = 0x8, // command occpuy failure - RT_ERROR_SET_SIGNAL_FAILURE = 0x9, // set signal failure - RT_ERROR_UNSET_SIGNAL_FAILURE = 0xA, // unset signal failure - RT_ERROR_OPEN_FILE_FAILURE = 0xB, // unset signal failure - RT_ERROR_WRITE_FILE_FAILURE = 0xC, - RT_ERROR_MEMORY_ADDRESS_UNALIGNED = 0xD, - RT_ERROR_DRV_ERR = 0xE, - RT_ERROR_LOST_HEARTBEAT = 0xF, - RT_ERROR_REPORT_TIMEOUT = 0x10, - RT_ERROR_NOT_READY = 0x11, - RT_ERROR_DATA_OPERATION_FAIL = 0x12, - RT_ERROR_INVALID_L2_INSTR_SIZE = 0x13, - RT_ERROR_DEVICE_PROC_HANG_OUT = 0x14, - RT_ERROR_DEVICE_POWER_UP_FAIL = 0x15, - RT_ERROR_DEVICE_POWER_DOWN_FAIL = 0x16, - RT_ERROR_FEATURE_NOT_SUPPROT = 0x17, - RT_ERROR_KERNEL_DUPLICATE = 0x18, // register same kernel repeatly - RT_ERROR_STREAM_DUPLICATE = 0x19, // streamId Map is repeatly - RT_ERROR_STREAM_NOT_EXIST = 0x1a, 
// streamId is not exist - RT_ERROR_SQ_NO_EXIST_SQ_TO_REUSE = 0x1b, // no exist sq to reuse - RT_ERROR_SQID_FULL = 0x3C, - RT_ERROR_MODEL_STREAM_EXE_FAILED = 0x91, // the model stream failed - RT_ERROR_MODEL_LOAD_FAILED = 0x94, // the model stream failed - RT_ERROR_END_OF_SEQUENCE = 0x95, // end of sequence - RT_ERROR_NO_STREAM_CB_REG = 0x96, // no callback register info for stream - RT_ERROR_DATA_DUMP_LOAD_FAILED = 0x97, // data dump load info fail - RT_ERROR_CALLBACK_THREAD_UNSUBSTRIBE = 0x98, // callback thread unsubstribe - RT_ERROR_DEBUG_REGISTER_FAILED = 0x99, // debug register fail - RT_ERROR_DEBUG_UNREGISTER_FAILED = 0x9A, // debug unregister fail - RT_ERROR_GROUP_NOT_SET = 0x9B, - RT_ERROR_GROUP_NOT_CREATE = 0x9C, - RT_ERROR_RESERVED -} rtError_t; + RT_ERROR_NONE = 0x0, // success + + RT_ERROR_DEVICE_BASE = 0x07010000, + RT_ERROR_DEVICE_NULL, + RT_ERROR_DEVICE_NEW, + RT_ERROR_DEVICE_ID, + RT_ERROR_DEVICE_CHIPTYPE, + RT_ERROR_DEVICE_DEPLOY, + RT_ERROR_DEVICE_RETAIN, + RT_ERROR_DEVICE_PLATFORM, + RT_ERROR_DEVICE_LOADER, + RT_ERROR_DEVICE_LIMIT, + RT_ERROR_DEVICE_PROC_HANG_OUT, + RT_ERROR_DEVICE_POWER_UP_FAIL, + RT_ERROR_DEVICE_POWER_DOWN_FAIL, + RT_ERROR_DEVICE_INVALID, + + RT_ERROR_DRV_BASE = 0x07020000, + RT_ERROR_DRV_NULL, + RT_ERROR_DRV_NEW, + RT_ERROR_DRV_MEMORY, + RT_ERROR_DRV_INPUT, + RT_ERROR_DRV_PTRNULL, + RT_ERROR_DRV_OPEN_AICPU, + RT_ERROR_DRV_CLOSE_AICPU, + RT_ERROR_DRV_SYM_AICPU, + RT_ERROR_DRV_OPEN_TSD, + RT_ERROR_DRV_CLOSE_TSD, + RT_ERROR_DRV_SYM_TSD, + RT_ERROR_DRV_SOURCE, + RT_ERROR_DRV_REPORT, + RT_ERROR_DRV_COMMAND, + RT_ERROR_DRV_OCCUPY, + RT_ERROR_DRV_ERR, + + RT_ERROR_STREAM_BASE = 0x07030000, + RT_ERROR_STREAM_NULL, + RT_ERROR_STREAM_NEW, + RT_ERROR_STREAM_CONTEXT, + RT_ERROR_STREAM_INVALID, + RT_ERROR_STREAM_MODEL, + RT_ERROR_STREAM_FUSION, + RT_ERROR_STREAM_FULL, + RT_ERROR_STREAM_EMPTY, + RT_ERROR_STREAM_NOT_COMPLETE, + RT_ERROR_STREAM_SYNC, + RT_ERROR_STREAM_NO_CB_REG, + RT_ERROR_STREAM_DUPLICATE, + RT_ERROR_STREAM_NOT_EXIST, + 
RT_ERROR_SQ_NO_EXIST_SQ_TO_REUSE, + RT_ERROR_SQID_FULL, + + RT_ERROR_MODEL_BASE = 0x07040000, + RT_ERROR_MODEL_NULL, + RT_ERROR_MODEL_NEW, + RT_ERROR_MODEL_CONTEXT, + RT_ERROR_MODEL_ENDGRAPH, + RT_ERROR_MODEL_STREAM, + RT_ERROR_MODEL_EXCUTOR, + RT_ERROR_MODEL_SETUP, + RT_ERROR_MODEL_ID, + RT_ERROR_MODEL_EXE_FAILED, + RT_ERROR_END_OF_SEQUENCE, // end of sequence + + RT_ERROR_EVENT_BASE = 0x07050000, + RT_ERROR_EVENT_NULL, + RT_ERROR_EVENT_NEW, + RT_ERROR_EVENT_RECORDER_NULL, + RT_ERROR_EVENT_TIMESTAMP_INVALID, + RT_ERROR_EVENT_TIMESTAMP_REVERSAL, + RT_ERROR_EVENT_NOT_COMPLETE, + + RT_ERROR_NOTIFY_BASE = 0x07060000, + RT_ERROR_NOTIFY_NULL, + RT_ERROR_NOTIFY_NEW, + RT_ERROR_NOTIFY_TYPE, + RT_ERROR_NOTIFY_NOT_COMPLETE, + + RT_ERROR_CONTEXT_BASE = 0x07070000, + RT_ERROR_CONTEXT_NULL, + RT_ERROR_CONTEXT_NEW, + RT_ERROR_CONTEXT_DEL, + RT_ERROR_CONTEXT_DEFAULT_STREAM_NULL, + RT_ERROR_CONTEXT_ONLINE_STREAM_NULL, + + RT_ERROR_KERNEL_BASE = 0x07080000, + RT_ERROR_KERNEL_NULL, + RT_ERROR_KERNEL_NEW, + RT_ERROR_KERNEL_LOOKUP, + RT_ERROR_KERNEL_NAME, + RT_ERROR_KERNEL_TYPE, + RT_ERROR_KERNEL_OFFSET, + RT_ERROR_KERNEL_DUPLICATE, + + RT_ERROR_PROGRAM_BASE = 0x07090000, + RT_ERROR_PROGRAM_NULL, + RT_ERROR_PROGRAM_NEW, + RT_ERROR_PROGRAM_DATA, + RT_ERROR_PROGRAM_SIZE, + RT_ERROR_PROGRAM_MEM_TYPE, + RT_ERROR_PROGRAM_MACHINE_TYPE, + RT_ERROR_PROGRAM_USEOUT, + + RT_ERROR_MODULE_BASE = 0x070a0000, + RT_ERROR_MODULE_NULL, + RT_ERROR_MODULE_NEW, + + RT_ERROR_INSTANCE_BASE = 0x070b0000, + RT_ERROR_INSTANCE_NULL, + RT_ERROR_INSTANCE_NEW, + RT_ERROR_INSTANCE_VERSION, + + RT_ERROR_API_BASE = 0x070c0000, + RT_ERROR_API_NULL, + RT_ERROR_API_NEW, + + RT_ERROR_DATADUMP_BASE = 0x070d0000, + RT_ERROR_DATADUMP_NULL, + RT_ERROR_DATADUMP_NEW, + RT_ERROR_DATADUMP_TIME, + RT_ERROR_DATADUMP_FILE, + RT_ERROR_DATADUMP_ADDRESS, + RT_ERROR_DATADUMP_LOAD_FAILED, + RT_ERROR_DUMP_ADDR_SET_FAILED, + + RT_ERROR_PROF_BASE = 0x070e0000, + RT_ERROR_PROF_NULL, + RT_ERROR_PROF_NEW, + RT_ERROR_PROF_START, + 
RT_ERROR_PROF_DEVICE_MEM, + RT_ERROR_PROF_HOST_MEM, + RT_ERROR_PROF_SET_DIR, + RT_ERROR_PROF_OPER, + RT_ERROR_PROF_FULL, + RT_ERROR_PROF_NAME, + + RT_ERROR_PCTRACE_BASE = 0x070f0000, + RT_ERROR_PCTRACE_NULL, + RT_ERROR_PCTRACE_NEW, + RT_ERROR_PCTRACE_TIME, + RT_ERROR_PCTRACE_FILE, + + RT_ERROR_TASK_BASE = 0x07100000, + RT_ERROR_TASK_NULL, + RT_ERROR_TASK_NEW, + RT_ERROR_TASK_TYPE, + RT_ERROR_TASK_ALLOCATOR, + + RT_ERROR_COMMON_BASE = 0x07110000, + RT_ERROR_INVALID_VALUE, // RT_ERROR_INPUT_INVALID + RT_ERROR_MEMORY_ADDRESS_UNALIGNED, + RT_ERROR_SEC_HANDLE, + RT_ERROR_OS_HANDLE, + RT_ERROR_MUTEX_LOCK, + RT_ERROR_MUTEX_UNLOCK, + RT_ERROR_CALLOC, + RT_ERROR_POOL_RESOURCE, + RT_ERROR_TRANS_ARGS, + RT_ERROR_METADATA, + RT_ERROR_LOST_HEARTBEAT, + RT_ERROR_REPORT_TIMEOUT, + RT_ERROR_FEATURE_NOT_SUPPROT, + RT_ERROR_MEMORY_ALLOCATION, + RT_ERROR_MEMORY_FREE, + + RT_ERROR_DEBUG_BASE = 0x07120000, + RT_ERROR_DEBUG_NULL, + RT_ERROR_DEBUG_NEW, + RT_ERROR_DEBUG_SIGNAL, + RT_ERROR_DEBUG_OPEN, + RT_ERROR_DEBUG_WRITE, + RT_ERROR_DEBUG_REGISTER_FAILED, + RT_ERROR_DEBUG_UNREGISTER_FAILED, + + RT_ERROR_ENGINE_BASE = 0x07130000, + RT_ERROR_ENGINE_NULL, + RT_ERROR_ENGINE_NEW, + RT_ERROR_ENGINE_THREAD, + + RT_ERROR_LABEL_BASE = 0x07140000, + RT_ERROR_LABEL_NULL, + RT_ERROR_LABEL_NEW, + RT_ERROR_LABEL_CONTEXT, + RT_ERROR_LABEL_STREAM, + RT_ERROR_LABEL_MODEL, + RT_ERROR_LABEL_ALLOCATOR, + RT_ERROR_LABEL_FREE, + RT_ERROR_LABEL_SET, + RT_ERROR_LABEL_ID, + + RT_ERROR_TSFW_BASE = 0x07150000, + RT_ERROR_TSFW_UNKNOWN, + RT_ERROR_TSFW_NULL_PTR, + RT_ERROR_TSFW_ILLEGAL_AI_CORE_ID, + RT_ERROR_TSFW_ILLEGAL_PARAM, + RT_ERROR_TSFW_TASK_CMD_QUEUE_FULL, + RT_ERROR_TSFW_TASK_CMD_QUEUE_EMPTY, + RT_ERROR_TSFW_TASK_REPORT_QUEUE_FULL, + RT_ERROR_TSFW_TASK_REPORT_QUEUE_EMPTY, + RT_ERROR_TSFW_TASK_NODE_BUFF_ALL_OCCUPYED, + RT_ERROR_TSFW_TASK_NODE_BUFF_ALL_FREED, + RT_ERROR_TSFW_L2_MEM_INSUFFICIENT_SPACE, + RT_ERROR_TSFW_L2_MALLOC_FAILED, + RT_ERROR_TSFW_DMA_CHANNEL_ALL_OCCUPYED, + 
RT_ERROR_TSFW_MEMCPY_OP_FAILED, + RT_ERROR_TSFW_BS_SLOT_ALL_OCCUPYED, + RT_ERROR_TSFW_TBS_SLOT_REPEAT_FREE, + RT_ERROR_TSFW_PRIORITY_TASK_LIST_FULL, + RT_ERROR_TSFW_PRIORITY_TASK_LIST_EMPTY, + RT_ERROR_TSFW_NO_STREAM_LIST_NEED_TO_BE_PROCESSED, + RT_ERROR_TSFW_REPEAT_MARK_STREAM_NEED_SERVICE, + RT_ERROR_TSFW_SYS_DMA_CHANNEL_ALL_OCCUPAPYED, + RT_ERROR_TSFW_NO_HBML2TASKNODE_FOUND, + RT_ERROR_TSFW_SQNODE_NODE_SLOT_ALL_OCCUPAPYED, + RT_ERROR_TSFW_CQNODE_NODE_SLOT_ALL_OCCUPAPYED, + RT_ERROR_TSFW_SQNODE_NOT_ENOUGH, + RT_ERROR_TSFW_SQNODE_SLOT_REPEAT_FREE, + RT_ERROR_TSFW_CQNODE_SLOT_REPEAT_FREE, + RT_ERROR_TSFW_CQ_REPORT_FAILED, + RT_ERROR_TSFW_SYS_DMA_RESET_SUCCESS, + RT_ERROR_TSFW_SYS_DMA_RESET_FAILED, + RT_ERROR_TSFW_SYS_DMA_TRNSFER_FAILED, + RT_ERROR_TSFW_SYS_DMA_MEMADDRALIGN_FAILED, + RT_ERROR_TSFW_SYS_DMA_ERROR_QUEUE_FULL, + RT_ERROR_TSFW_SYS_DMA_ERROR_QUEUE_EMPTY, + RT_ERROR_TSFW_TIMER_EVENT_FULL, + RT_ERROR_TSFW_TASK_L2_DESC_ENTRY_NOT_ENOUGH, + RT_ERROR_TSFW_AICORE_TIMEOUT, + RT_ERROR_TSFW_AICORE_EXCEPTION, + RT_ERROR_TSFW_AICORE_TRAP_EXCEPTION, + RT_ERROR_TSFW_AICPU_TIMEOUT, + RT_ERROR_TSFW_SDMA_L2_TO_DDR_MALLOC_FAIL, + RT_ERROR_TSFW_AICPU_EXCEPTION, + RT_ERROR_TSFW_AICPU_DATADUMP_RSP_ERR, + RT_ERROR_TSFW_AICPU_MODEL_RSP_ERR, + RT_ERROR_TSFW_REPEAT_ACTIVE_MODEL_STREAM, + RT_ERROR_TSFW_REPEAT_NOTIFY_WAIT, + RT_ERROR_TSFW_DEBUG_INVALID_SQCQ, + RT_ERROR_TSFW_DEBUG_WRONG_COMMAND_TYPE, + RT_ERROR_TSFW_DEBUG_CMD_PROCESS, + RT_ERROR_TSFW_DEBUG_INVALID_DEVICE_STATUS, + RT_ERROR_TSFW_DEBUG_NOT_IN_DEBUG_STATUS, + RT_ERROR_TSFW_DEBUG_INVALID_TASK_STATUS, + RT_ERROR_TSFW_DEBUG_TASK_EMPTY, + RT_ERROR_TSFW_DEBUG_TASK_FULL, + RT_ERROR_TSFW_DEBUG_TASK_NOT_EXIST, + RT_ERROR_TSFW_DEBUG_AI_CORE_FULL, + RT_ERROR_TSFW_DEBUG_AI_CORE_NOT_EXIST, + RT_ERROR_TSFW_DEBUG_AI_CORE_EXCEPTION, + RT_ERROR_TSFW_DEBUG_AI_CORE_TIMEOUT, + RT_ERROR_TSFW_DEBUG_BREAKPOINT_FULL, + RT_ERROR_TSFW_DEBUG_READ_ERROR, + RT_ERROR_TSFW_DEBUG_WRITE_FAIL, + RT_ERROR_TSFW_QUEUE_FULL, + RT_ERROR_TSFW_QUEUE_EMPTY, + 
RT_ERROR_TSFW_QUEUE_ALLOC_MEM_FAIL, + RT_ERROR_TSFW_QUEUE_DATA_SIZE_UNMATCH, + RT_ERROR_TSFW_PCIE_DMA_INVLD_CPY_TYPE, + RT_ERROR_TSFW_INVLD_CPY_DIR, + RT_ERROR_TSFW_PCIE_DMA_INVLD_CQ_DES, + RT_ERROR_TSFW_PCIE_DMA_CPY_ERR, + RT_ERROR_TSFW_PCIE_DMA_LNK_CHN_BUSY, + RT_ERROR_TSFW_PROFILE_BUFF_FULL, + RT_ERROR_TSFW_PROFILE_MODE_CONFLICT, + RT_ERROR_TSFW_PROFILE_OTHER_PID_ON, + RT_ERROR_TSFW_SCHD_AIC_TASK_PRELOAD_FAILED, + RT_ERROR_TSFW_TSCPU_CLOSE_FAILED, + RT_ERROR_TSFW_EXPECT_FAIL, + RT_ERROR_TSFW_REPEAT_MODEL_STREAM, + RT_ERROR_TSFW_STREAM_MODEL_UNBIND, + RT_ERROR_TSFW_MODEL_EXE_FAILED, + RT_ERROR_TSFW_IPC_SEND_FAILED, + RT_ERROR_TSFW_IPC_PROC_REG_FAILED, + RT_ERROR_TSFW_STREAM_FULL, + RT_ERROR_TSFW_END_OF_SEQUENCE, + RT_ERROR_TSFW_SWITCH_STREAM_LABEL, + RT_ERROR_TSFW_TRANS_SQE_FAIL, + RT_ERROR_TSFW_RESERVED, + + RT_ERROR_SUBSCRIBE_BASE = 0x07160000, + RT_ERROR_SUBSCRIBE_NULL, + RT_ERROR_SUBSCRIBE_NEW, + RT_ERROR_SUBSCRIBE_STREAM, + RT_ERROR_SUBSCRIBE_THREAD, + RT_ERROR_SUBSCRIBE_GROUP, + + RT_ERROR_GROUP_BASE = 0x07170000, + RT_ERROR_GROUP_NOT_SET, + RT_ERROR_GROUP_NOT_CREATE, + + RT_ERROR_RESERVED = 0x07ff0000, + }rtError_t; /** * @ingroup dvrt_base @@ -88,7 +339,8 @@ typedef enum tagRtExceptionType { RT_EXCEPTION_TS_DOWN = 1, RT_EXCEPTION_TASK_TIMEOUT = 2, RT_EXCEPTION_TASK_FAILURE = 3, - RT_EXCEPTION_DEV_RUNNING_DOWN = 4 + RT_EXCEPTION_DEV_RUNNING_DOWN = 4, + RT_EXCEPTION_STREAM_ID_FREE_FAILED = 5 } rtExceptionType; /** @@ -126,6 +378,7 @@ typedef struct rtExceptionInfo { uint32_t taskid; uint32_t streamid; uint32_t tid; + uint32_t deviceid; } rtExceptionInfo; typedef void (*rtErrorCallback)(rtExceptionType); @@ -225,7 +478,7 @@ typedef void *rtNotify_t; * @brief create label instance * @param [out] label created label * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input handle + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtLabelCreate(rtLabel_t *label); @@ -235,7 +488,7 @@ RTS_API rtError_t 
rtLabelCreate(rtLabel_t *label); * @param [in] label set label * @param [in] stream set stream * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input handle + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtLabelSet(rtLabel_t label, rtStream_t stream); @@ -244,7 +497,7 @@ RTS_API rtError_t rtLabelSet(rtLabel_t label, rtStream_t stream); * @brief destroy label instance * @param [in] label label to destroy * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input handle + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtLabelDestroy(rtLabel_t label); @@ -257,7 +510,7 @@ RTS_API rtError_t rtLabelDestroy(rtLabel_t label); * @param [in] true_label goto label * @param [in] stream to submit label_switch task * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input handle + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtLabelSwitch(void *ptr, rtCondition_t condition, uint32_t value, rtLabel_t trueLabel, rtStream_t stream); @@ -268,7 +521,7 @@ RTS_API rtError_t rtLabelSwitch(void *ptr, rtCondition_t condition, uint32_t val * @param [in] label goto label * @param [in] stream to submit label_goto task * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input handle + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtLabelGoto(rtLabel_t label, rtStream_t stream); @@ -278,7 +531,7 @@ RTS_API rtError_t rtLabelGoto(rtLabel_t label, rtStream_t stream); * @param [in] label instance * @param [in] name label name * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input handle + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtNameLabel(rtLabel_t label, const char *name); @@ -290,7 +543,7 @@ RTS_API rtError_t rtNameLabel(rtLabel_t label, const char *name); * @param [in] labelInfoPtr label content 
info ptr * @param [in] stream set stream * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input handle + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtLabelSwitchByIndex(void *ptr, uint32_t max, void *labelInfoPtr, rtStream_t stream); @@ -300,7 +553,7 @@ RTS_API rtError_t rtLabelSwitchByIndex(void *ptr, uint32_t max, void *labelInfoP * @param [in] label goto label * @param [in] stream stream to submit label_goto task * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input handle + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtLabelGotoEx(rtLabel_t label, rtStream_t stream); @@ -312,7 +565,7 @@ RTS_API rtError_t rtLabelGotoEx(rtLabel_t label, rtStream_t stream); * @param [in] dst device ptr * @param [in] dstMax dst size * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input handle + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtLabelListCpy(rtLabel_t *label, uint32_t labelNumber, void *dst, uint32_t dstMax); @@ -322,7 +575,7 @@ RTS_API rtError_t rtLabelListCpy(rtLabel_t *label, uint32_t labelNumber, void *d * @param [out] label created label handle * @param [in] stream label bind stream * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input handle + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtLabelCreateEx(rtLabel_t *label, rtStream_t stream); #ifdef __cplusplus diff --git a/third_party/fwkacllib/inc/runtime/context.h b/third_party/fwkacllib/inc/runtime/context.h index 70437b74..cc74a5ed 100644 --- a/third_party/fwkacllib/inc/runtime/context.h +++ b/third_party/fwkacllib/inc/runtime/context.h @@ -41,7 +41,7 @@ typedef enum tagCtxMode { typedef struct tagRtGroupInfo { int32_t groupId; - int32_t flag; + uint32_t flag; uint32_t aicoreNum; uint32_t aicpuNum; uint32_t aivectorNum; diff --git 
a/third_party/fwkacllib/inc/runtime/dev.h b/third_party/fwkacllib/inc/runtime/dev.h index f79f060c..bf2ce447 100644 --- a/third_party/fwkacllib/inc/runtime/dev.h +++ b/third_party/fwkacllib/inc/runtime/dev.h @@ -72,7 +72,7 @@ typedef enum tagMemcpyInfo { * @brief get total device number. * @param [in|out] count the device number * @return RT_ERROR_NONE for ok - * @return RT_ERROR_NO_DEVICE for can not find any device + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtGetDeviceCount(int32_t *count); /** @@ -114,7 +114,7 @@ RTS_API rtError_t rtGetDeviceIDs(uint32_t *devices, uint32_t len); } DEV_INFO_TYPE; * @param [out] value the device info * @return RT_ERROR_NONE for ok - * @return RT_ERROR_NO_DEVICE for can not find any device + * @return RT_ERROR_DRV_ERR for error */ RTS_API rtError_t rtGetDeviceInfo(uint32_t deviceId, int32_t moduleType, int32_t infoType, int64_t *value); @@ -123,7 +123,7 @@ RTS_API rtError_t rtGetDeviceInfo(uint32_t deviceId, int32_t moduleType, int32_t * @brief set target device for current thread * @param [int] device the device id * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_DEVICE for can not match ID and device + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtSetDevice(int32_t device); @@ -132,7 +132,7 @@ RTS_API rtError_t rtSetDevice(int32_t device); * @brief set target device for current thread * @param [int] device the device id * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_DEVICE for can not match ID and device + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtSetDeviceEx(int32_t device); @@ -142,7 +142,7 @@ RTS_API rtError_t rtSetDeviceEx(int32_t device); * @param [in] phyId the physical device id * @param [out] devIndex the logic device id * @return RT_ERROR_NONE for ok - * @return RT_ERROR_NO_DEVICE for can not find any device + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtGetDeviceIndexByPhyId(uint32_t 
phyId, uint32_t *devIndex); @@ -152,7 +152,7 @@ RTS_API rtError_t rtGetDeviceIndexByPhyId(uint32_t phyId, uint32_t *devIndex); * @param [in] devIndex the logic device id * @param [out] phyId the physical device id * @return RT_ERROR_NONE for ok - * @return RT_ERROR_NO_DEVICE for can not find any device + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtGetDevicePhyIdByIndex(uint32_t devIndex, uint32_t *phyId); @@ -162,7 +162,7 @@ RTS_API rtError_t rtGetDevicePhyIdByIndex(uint32_t devIndex, uint32_t *phyId); * @param [in] devIdDes the logical device id * @param [in] phyIdSrc the physical device id * @return RT_ERROR_NONE for ok - * @return RT_ERROR_NO_DEVICE for can not find any device + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtEnableP2P(uint32_t devIdDes, uint32_t phyIdSrc); @@ -172,7 +172,7 @@ RTS_API rtError_t rtEnableP2P(uint32_t devIdDes, uint32_t phyIdSrc); * @param [in] devIdDes the logical device id * @param [in] phyIdSrc the physical device id * @return RT_ERROR_NONE for ok - * @return RT_ERROR_NO_DEVICE for can not find any device + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtDisableP2P(uint32_t devIdDes, uint32_t phyIdSrc); @@ -183,7 +183,7 @@ RTS_API rtError_t rtDisableP2P(uint32_t devIdDes, uint32_t phyIdSrc); * @param [in] phyIdSrc the physical device id * @param [in|out] status status value * @return RT_ERROR_NONE for ok - * @return RT_ERROR_NO_DEVICE for can not find any device + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtGetP2PStatus(uint32_t devIdDes, uint32_t phyIdSrc, uint32_t *status); @@ -208,7 +208,7 @@ RTS_API rtError_t rtGetDevice(int32_t *device); * @ingroup dvrt_dev * @brief reset all opened device * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_DEVICE if no device set + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtDeviceReset(int32_t device); @@ -216,7 +216,7 @@ RTS_API rtError_t 
rtDeviceReset(int32_t device); * @ingroup dvrt_dev * @brief reset opened device * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_DEVICE if no device set + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtDeviceResetEx(int32_t device); @@ -228,7 +228,7 @@ RTS_API rtError_t rtDeviceResetEx(int32_t device); * @param [in] value limit value * @param [out] info the device info * @return RT_ERROR_NONE for ok - * @return RT_ERROR_NO_DEVICE for can not find any device + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtDeviceSetLimit(int32_t device, rtLimitType_t type, uint32_t value); @@ -236,7 +236,7 @@ RTS_API rtError_t rtDeviceSetLimit(int32_t device, rtLimitType_t type, uint32_t * @ingroup dvrt_dev * @brief Wait for compute device to finish * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_DEVICE if no device set + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtDeviceSynchronize(void); diff --git a/third_party/fwkacllib/inc/runtime/dvfsprofile.h b/third_party/fwkacllib/inc/runtime/dvfsprofile.h index 11081546..60f400b3 100644 --- a/third_party/fwkacllib/inc/runtime/dvfsprofile.h +++ b/third_party/fwkacllib/inc/runtime/dvfsprofile.h @@ -36,8 +36,6 @@ typedef enum dvfsProfileMode { * @param [in] mode dvfsProfileMode * @return RT_ERROR_NONE for ok * @return RT_ERROR_INVALID_VALUE for error input - * @return RT_ERROR_INVALID_DEVICE for invalid device handle - * @return RT_ERROR_OPEN_FILE_FAILURE for invalid file handle */ RTS_API rtError_t rtSetDvfsProfile(DvfsProfileMode mode); @@ -46,7 +44,6 @@ RTS_API rtError_t rtSetDvfsProfile(DvfsProfileMode mode); * @brief Set the performance mode of the device * @return RT_ERROR_NONE for ok * @return RT_ERROR_INVALID_VALUE for invalid value - * @return RT_ERROR_INVALID_DEVICE for invalid device handle */ RTS_API rtError_t rtUnsetDvfsProfile(); @@ -56,9 +53,6 @@ RTS_API rtError_t rtUnsetDvfsProfile(); * @param [in|out] pmode dvfsProfileMode 
type pointer * @return RT_ERROR_NONE for ok * @return RT_ERROR_INVALID_VALUE for error input - * @return RT_ERROR_INVALID_DEVICE for invalid device handle - * @return RT_ERROR_OPEN_FILE_FAILURE for invalid file handle - * @return RT_ERROR_NO_DEVICE for reading npu_freq_dnlimit failed */ RTS_API rtError_t rtGetDvfsProfile(DvfsProfileMode *pmode); diff --git a/third_party/fwkacllib/inc/runtime/event.h b/third_party/fwkacllib/inc/runtime/event.h index 31991cf7..9dc44766 100644 --- a/third_party/fwkacllib/inc/runtime/event.h +++ b/third_party/fwkacllib/inc/runtime/event.h @@ -35,7 +35,7 @@ extern "C" { * @brief create event instance * @param [in|out] event created event * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input handle + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtEventCreate(rtEvent_t *event); @@ -44,7 +44,7 @@ RTS_API rtError_t rtEventCreate(rtEvent_t *event); * @brief create event instance with flag * @param [in|out] event created event flag event op flag * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input handle + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtEventCreateWithFlag(rtEvent_t *event, uint32_t flag); @@ -53,7 +53,7 @@ RTS_API rtError_t rtEventCreateWithFlag(rtEvent_t *event, uint32_t flag); * @brief destroy event instance * @param [in] event event to destroy * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input handle + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtEventDestroy(rtEvent_t event); @@ -63,7 +63,7 @@ RTS_API rtError_t rtEventDestroy(rtEvent_t event); * @param [int] event event to record * @param [int] stream stream handle * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input handle + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtEventRecord(rtEvent_t event, rtStream_t 
stream); @@ -81,7 +81,7 @@ RTS_API rtError_t rtEventReset(rtEvent_t event, rtStream_t stream); * @brief wait event to be complete * @param [in] event event to wait * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input handle + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtEventSynchronize(rtEvent_t event); @@ -90,7 +90,7 @@ RTS_API rtError_t rtEventSynchronize(rtEvent_t event); * @brief Queries an event's status * @param [in] event event to query * @return RT_ERROR_NONE for complete - * @return RT_ERROR_NOT_READY for not complete + * @return RT_ERROR_EVENT_NOT_COMPLETE for not complete */ RTS_API rtError_t rtEventQuery(rtEvent_t event); @@ -131,7 +131,6 @@ RTS_API rtError_t rtNameEvent(rtEvent_t event_, const char *name); * @param [in|out] notify_ notify to be created * @return RT_ERROR_NONE for ok * @return RT_ERROR_INVALID_VALUE for error input - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for invalid resource handle */ RTS_API rtError_t rtNotifyCreate(int32_t device_id, rtNotify_t *notify_); @@ -152,8 +151,7 @@ RTS_API rtError_t rtNotifyDestroy(rtNotify_t notify_); * @param [in] stream_ input stream * @return RT_ERROR_NONE for ok * @return RT_ERROR_INVALID_VALUE for error input - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for invalid resource handle - * @return RT_ERROR_INVALID_DEVICE for stream is not in current ctx + * @return RT_ERROR_STREAM_CONTEXT for stream is not in current ctx */ RTS_API rtError_t rtNotifyRecord(rtNotify_t notify_, rtStream_t stream_); @@ -164,8 +162,7 @@ RTS_API rtError_t rtNotifyRecord(rtNotify_t notify_, rtStream_t stream_); * @param [in] stream_ input stream * @return RT_ERROR_NONE for ok * @return RT_ERROR_INVALID_VALUE for error input - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for invalid resource handle - * @return RT_ERROR_INVALID_DEVICE for stream is not in current ctx + * @return RT_ERROR_STREAM_CONTEXT for stream is not in current ctx */ RTS_API rtError_t 
rtNotifyWait(rtNotify_t notify_, rtStream_t stream_); @@ -206,7 +203,6 @@ RTS_API rtError_t rtIpcSetNotifyName(rtNotify_t notify, char *name, uint32_t len * @param [in] name identification name * @return RT_ERROR_NONE for ok * @return RT_ERROR_INVALID_VALUE for error input - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for invalid resource handle */ RTS_API rtError_t rtIpcOpenNotify(rtNotify_t *notify, const char *name); diff --git a/third_party/fwkacllib/inc/runtime/kernel.h b/third_party/fwkacllib/inc/runtime/kernel.h index c99eb96f..aec290da 100644 --- a/third_party/fwkacllib/inc/runtime/kernel.h +++ b/third_party/fwkacllib/inc/runtime/kernel.h @@ -176,6 +176,13 @@ typedef void (*rtCallback_t)(void *fnData); #define RT_KERNEL_DEFAULT (0x00) #define RT_KERNEL_CONVERT (0x01) #define RT_KERNEL_DUMPFLAG (0x02) +#define RT_FUSION_KERNEL_DUMPFLAG (0x04) + +/** + * @ingroup rt_kernel + * @brief kernel L1 Fusion Dump bit flags + */ +#define RT_DDR_ADDR (0x0) /** * @ingroup rt_kernel @@ -183,7 +190,7 @@ typedef void (*rtCallback_t)(void *fnData); * @param [in] bin device binary description * @param [out] handle device binary handle * @return RT_ERROR_NONE for ok - * @note:if this interface is changed, pls notify the compiler changing at the same time. + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtDevBinaryRegister(const rtDevBinary_t *bin, void **handle); @@ -192,7 +199,7 @@ RTS_API rtError_t rtDevBinaryRegister(const rtDevBinary_t *bin, void **handle); * @brief register fast memeory device binary * @param [in] handle device binary handle * @return RT_ERROR_NONE for ok - * @note:if this interface is changed, pls notify the compiler changing at the same time. 
+ * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtBinaryRegisterToFastMemory(void *handle); @@ -201,7 +208,7 @@ RTS_API rtError_t rtBinaryRegisterToFastMemory(void *handle); * @brief unregister device binary * @param [in] handle device binary handle * @return RT_ERROR_NONE for ok - * @note:if this interface is changed, pls notify the compiler changing at the same time. + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtDevBinaryUnRegister(void *handle); @@ -211,7 +218,7 @@ RTS_API rtError_t rtDevBinaryUnRegister(void *handle); * @param [in] handle device binary description * @param [in] metadata device binary metadata * @return RT_ERROR_NONE for ok - * @note:if this interface is changed, pls notify the compiler changing at the same time. + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtMetadataRegister(void *handle, const char *metadata); @@ -221,7 +228,7 @@ RTS_API rtError_t rtMetadataRegister(void *handle, const char *metadata); * @param [in] mHandle master device binary description * @param [in] sHandle slave device binary description * @return RT_ERROR_NONE for ok - * @note:if this interface is changed, pls notify the compiler changing at the same time. + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtDependencyRegister(void *mHandle, void *sHandle); @@ -234,7 +241,7 @@ RTS_API rtError_t rtDependencyRegister(void *mHandle, void *sHandle); * @param [in] devFunc device function description. symbol name or address * offset, depending binary type. * @return RT_ERROR_NONE for ok - * @note:if this interface is changed, pls notify the compiler changing at the same time. 
+ * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtFunctionRegister(void *binHandle, const void *stubFunc, const char *stubName, const void *devFunc, uint32_t funcMode); @@ -245,6 +252,7 @@ RTS_API rtError_t rtFunctionRegister(void *binHandle, const void *stubFunc, cons * @param [in] stubName stub function name * @param [out] stubFunc stub function * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtGetFunctionByName(const char *stubName, void **stubFunc); @@ -254,6 +262,7 @@ RTS_API rtError_t rtGetFunctionByName(const char *stubName, void **stubFunc); * @param [in] stubFunc stub function * @param [out] addr * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtGetAddrByFun(const void *stubFunc, void **addr); /** @@ -261,6 +270,7 @@ RTS_API rtError_t rtGetAddrByFun(const void *stubFunc, void **addr); * @brief query registered or not by stubName * @param [in] stubName stub function name * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtQueryFunctionRegistered(const char *stubName); @@ -270,7 +280,8 @@ RTS_API rtError_t rtQueryFunctionRegistered(const char *stubName); * @param [in] dumpSizePerBlock dump size * @param [in] blockDim block dimentions * @param [in] dumpBaseAddr dump base address - * @return RT_ERROR_NONE for ok, errno for failed + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtKernelConfigDump(uint32_t kind, uint32_t dumpSizePerBlock, uint32_t blockDim, void **dumpBaseAddr, rtStream_t stream_); @@ -284,7 +295,8 @@ RTS_API rtError_t rtKernelConfigDump(uint32_t kind, uint32_t dumpSizePerBlock, u * @param [in] argsSize argements size * @param [in] smDesc shared memory description * @param [in] stream associated stream - * @return RT_ERROR_NONE for ok, errno for failed + * @return RT_ERROR_NONE for ok + * @return 
RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtKernelLaunch(const void *stubFunc, uint32_t blockDim, void *args, uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream); @@ -299,7 +311,8 @@ RTS_API rtError_t rtKernelLaunch(const void *stubFunc, uint32_t blockDim, void * * @param [in] smDesc shared memory description * @param [in] stream associated stream * @param [in] flag dump flag - * @return RT_ERROR_NONE for ok, errno for failed + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtKernelLaunchWithFlag(const void *stubFunc, uint32_t blockDim, void *args, uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream, uint32_t flags); @@ -311,7 +324,8 @@ RTS_API rtError_t rtKernelLaunchWithFlag(const void *stubFunc, uint32_t blockDim * @param [in] argsSize argements size * @param [in] flags launch flags * @param [in] stream associated stream - * @return RT_ERROR_NONE for ok, errno for failed + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtKernelLaunchEx(void *args, uint32_t argsSize, uint32_t flags, rtStream_t stream); @@ -325,7 +339,8 @@ RTS_API rtError_t rtKernelLaunchEx(void *args, uint32_t argsSize, uint32_t flags * @param [in] argsSize argments size * @param [in] smDesc shared memory description * @param [in] stream associated stream - * @retval RT_ERROR_NONE for ok, errno for failed + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtCpuKernelLaunch(const void *soName, const void *kernelName, uint32_t blockDim, const void *args, uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream); @@ -341,18 +356,33 @@ RTS_API rtError_t rtCpuKernelLaunch(const void *soName, const void *kernelName, * @param [in] smDesc shared memory description * @param [in] stream associated stream * @param [in] flag dump flag or others function flag - * @retval RT_ERROR_NONE for ok, errno for failed + * 
@return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtCpuKernelLaunchWithFlag(const void *soName, const void *kernelName, uint32_t blockDim, const void *args, uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream, uint32_t flags); +typedef void *rtModel_t; +/** + * @ingroup rt_kernel + * @brief L1 fusion dump addr transferred to device + * @param [in] model handle info + * @param [in] addr ddr address of L1 Fusion Dump + * @param [in] dumpSize memory size + * @param [in] flag memory flag + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ + RTS_API rtError_t rtDumpAddrSet(rtModel_t model, void *addr, uint32_t dumpSize, uint32_t flag); + /** * @ingroup rt_kernel * @brief load dump info to aicpu * @param [in] dumpInfo dump info * @param [in] length length of dump info - * @return RT_ERROR_NONE for ok, errno for failed + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtDatadumpInfoLoad(const void *dumpInfo, uint32_t length); @@ -365,7 +395,7 @@ RTS_API rtError_t rtDatadumpInfoLoad(const void *dumpInfo, uint32_t length); * @param [in] smDesc shared memory description * @param [in] stream associated stream * @return RT_ERROR_NONE for ok - * @note:if this interface is changed, pls notify the compiler changing at the same time. + * @return RT_ERROR_INVALID_VALUE for error input */ #ifdef __cplusplus RTS_API rtError_t rtConfigureCall(uint32_t numBlocks, rtSmDesc_t *smDesc = nullptr, rtStream_t stream = nullptr); @@ -381,7 +411,7 @@ RTS_API rtError_t rtConfigureCall(uint32_t numBlocks, rtSmDesc_t *smDesc, rtStre * @param [in] size argment size * @param [in] offset argment table offset * @return RT_ERROR_NONE for ok - * @note:if this interface is changed, pls notify the compiler changing at the same time.
+ * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtSetupArgument(const void *arg, uint32_t size, uint32_t offset); @@ -391,7 +421,7 @@ RTS_API rtError_t rtSetupArgument(const void *arg, uint32_t size, uint32_t offse * and call argment * @param [in] stubFunc stub function * @return RT_ERROR_NONE for ok - * @note:if this interface is changed, pls notify the compiler changing at the same time. + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtLaunch(const void *stubFunc); @@ -404,6 +434,7 @@ RTS_API rtError_t rtLaunch(const void *stubFunc); * @param [in] flag reserved. set to 0 * @param [out] arg returned arg. used for next kernel's arg. * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtKernelConfigTransArg(const void *ptr, uint64_t size, uint32_t flag, void **arg); @@ -411,7 +442,8 @@ RTS_API rtError_t rtKernelConfigTransArg(const void *ptr, uint64_t size, uint32_ * @ingroup rt_kernel * @brief start fusion kernels. * @param [in] stream stream for fusion kernels - * @return RT_ERROR_NONE for ok, errno for failed + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtKernelFusionStart(rtStream_t stream); @@ -419,7 +451,8 @@ RTS_API rtError_t rtKernelFusionStart(rtStream_t stream); * @ingroup rt_kernel * @brief end fusion kernels. 
* @param [in] stream stream for fusion kernels - * @return RT_ERROR_NONE for ok, errno for failed + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtKernelFusionEnd(rtStream_t stream); @@ -427,7 +460,8 @@ RTS_API rtError_t rtKernelFusionEnd(rtStream_t stream); * @ingroup rt_kernel * @brief set kernelinfo callback * @param [in] callback - * @return RT_ERROR_NONE for ok, errno for failed + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtSetKernelReportCallback(rtKernelReportCallback callBack); @@ -436,7 +470,8 @@ RTS_API rtError_t rtSetKernelReportCallback(rtKernelReportCallback callBack); * @brief subscribe stream callback report. * @param [in] threadId thread id for stream * @param [in] stream stream for subscribe - * @return RT_ERROR_NONE for ok, errno for failed + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtSubscribeReport(uint64_t threadId, rtStream_t stream); @@ -446,7 +481,8 @@ RTS_API rtError_t rtSubscribeReport(uint64_t threadId, rtStream_t stream); * @param [in] callBackFunc app callback function * @param [in] fnData user data * @param [in] stream subscribed stream - * @return RT_ERROR_NONE for ok, errno for failed + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtCallbackLaunch(rtCallback_t callBackFunc, void *fnData, rtStream_t stream, bool isBlock); @@ -454,7 +490,8 @@ RTS_API rtError_t rtCallbackLaunch(rtCallback_t callBackFunc, void *fnData, rtSt * @ingroup rt_kernel * @brief process callback report. 
* @param [in] timeout if timeout=-1, while(1); else timeout - * @return RT_ERROR_NONE for ok, errno for failed + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtProcessReport(int32_t timeout); @@ -463,25 +500,32 @@ RTS_API rtError_t rtProcessReport(int32_t timeout); * @brief unsubscribe callback report. * @param [in] threadId thread id for stream * @param [in] stream stream for subscribe - * @return RT_ERROR_NONE for ok, errno for failed + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtUnSubscribeReport(uint64_t threadId, rtStream_t stream); /** * @ingroup profiling_base * @brief start online prof. + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtStartOnlineProf(rtStream_t stream, uint32_t sampleNum); /** * @ingroup profiling_base * @brief stop online prof. + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtStopOnlineProf(rtStream_t stream); /** * @ingroup profiling_base * @brief get online prof. 
+ * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtGetOnlineProfData(rtStream_t stream, rtProfDataInfo_t *pProfData, uint32_t profDataNum); #ifdef __cplusplus diff --git a/third_party/fwkacllib/inc/runtime/mem.h b/third_party/fwkacllib/inc/runtime/mem.h index e70ebd38..3280f3c6 100644 --- a/third_party/fwkacllib/inc/runtime/mem.h +++ b/third_party/fwkacllib/inc/runtime/mem.h @@ -154,7 +154,7 @@ typedef struct tagRtPointerAttributes { * @param [in] size memory size * @param [in] type memory type * @return RT_ERROR_NONE for ok - * @return RT_ERROR_MEMORY_ALLOCATION for memory allocation failed + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtMalloc(void **devPtr, uint64_t size, rtMemType_t type); @@ -163,7 +163,7 @@ RTS_API rtError_t rtMalloc(void **devPtr, uint64_t size, rtMemType_t type); * @brief free device memory * @param [in|out] devPtr memory pointer * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_DEVICE_POINTER for error device memory pointer + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtFree(void *devPtr); @@ -173,7 +173,7 @@ RTS_API rtError_t rtFree(void *devPtr); * @param [in|out] devPtr memory pointer * @param [in] size memory size * @return RT_ERROR_NONE for ok - * @return RT_ERROR_MEMORY_ALLOCATION for memory allocation failed + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtDvppMalloc(void **devPtr, uint64_t size); @@ -182,7 +182,7 @@ RTS_API rtError_t rtDvppMalloc(void **devPtr, uint64_t size); * @brief free device memory for dvpp * @param [in|out] devPtr memory pointer * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_DEVICE_POINTER for error device memory pointer + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtDvppFree(void *devPtr); @@ -192,7 +192,7 @@ RTS_API rtError_t rtDvppFree(void *devPtr); * @param [in|out] hostPtr memory pointer * @param [in] size memory size * 
@return RT_ERROR_NONE for ok - * @return RT_ERROR_MEMORY_ALLOCATION for memory allocation failed + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtMallocHost(void **hostPtr, uint64_t size); @@ -201,7 +201,7 @@ RTS_API rtError_t rtMallocHost(void **hostPtr, uint64_t size); * @brief free host memory * @param [in] hostPtr memory pointer * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_DEVICE_POINTER for error device memory pointer + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtFreeHost(void *hostPtr); @@ -212,7 +212,7 @@ RTS_API rtError_t rtFreeHost(void *hostPtr); * @param [in] size memory size * @param [in] flag reserved, set to 0. * @return RT_ERROR_NONE for ok - * @return RT_ERROR_MEMORY_ALLOCATION for memory allocation failed + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtMemAllocManaged(void **ptr, uint64_t size, uint32_t flag); @@ -221,7 +221,7 @@ RTS_API rtError_t rtMemAllocManaged(void **ptr, uint64_t size, uint32_t flag); * @brief free managed memory * @param [in] ptr memory pointer * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_DEVICE_POINTER for error device memory pointer + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtMemFreeManaged(void *ptr); /** @@ -261,9 +261,7 @@ RTS_API rtError_t rtInvalidCache(void *base, size_t len); * @param [in] count the number of byte to copy * @param [in] kind memcpy type * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_VALUE for error input of count - * @return RT_ERROR_INVALID_DEVICE_POINTER for error input memory pointer of dst,src - * @return RT_ERROR_INVALID_MEMCPY_DIRECTION for error copy direction of kind + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtMemcpy(void *dst, uint64_t destMax, const void *src, uint64_t count, rtMemcpyKind_t kind); @@ -277,9 +275,7 @@ RTS_API rtError_t rtMemcpy(void *dst, uint64_t destMax, const void *src, uint64_ * @param [in] 
kind memcpy type * @param [in] stream asynchronized task stream * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_VALUE for error input of count,stream - * @return RT_ERROR_INVALID_DEVICE_POINTER for error input memory pointer of dst,src - * @return RT_ERROR_INVALID_MEMCPY_DIRECTION for error copy direction of kind + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtMemcpyAsync(void *dst, uint64_t destMax, const void *src, uint64_t count, rtMemcpyKind_t kind, rtStream_t stream); @@ -295,9 +291,7 @@ RTS_API rtError_t rtMemcpyAsync(void *dst, uint64_t destMax, const void *src, ui * @param [in] type data type * @param [in] stream asynchronized task stream * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_VALUE for error input of count,stream - * @return RT_ERROR_INVALID_DEVICE_POINTER for error input memory pointer of dst,src - * @return RT_ERROR_INVALID_MEMCPY_DIRECTION for error copy direction of kind + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtReduceAsync(void *dst, uint64_t destMax, const void *src, uint64_t count, rtRecudeKind_t kind, rtDataType_t type, rtStream_t stream); @@ -307,6 +301,7 @@ RTS_API rtError_t rtReduceAsync(void *dst, uint64_t destMax, const void *src, ui * @brief query memory size * @param [in] aiCoreMemorySize * @return RT_ERROR_NONE for ok, errno for failed + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtAiCoreMemorySizes(rtAiCoreMemorySize_t *aiCoreMemorySize); @@ -316,6 +311,7 @@ RTS_API rtError_t rtAiCoreMemorySizes(rtAiCoreMemorySize_t *aiCoreMemorySize); integrated network due to memory limitations.Requirement come from JiaMinHu.Only use for Tiny. 
* @param [in] aiCoreMemorySize * @return RT_ERROR_NONE for ok, errno for failed + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtSetAiCoreMemorySizes(rtAiCoreMemorySize_t *aiCoreMemorySize); @@ -327,6 +323,7 @@ RTS_API rtError_t rtSetAiCoreMemorySizes(rtAiCoreMemorySize_t *aiCoreMemorySize) * @param [in] value * @param [in] count byte num * @return RT_ERROR_NONE for ok, errno for failed + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtMemset(void *devPtr, uint64_t destMax, uint32_t value, uint64_t count); @@ -339,6 +336,7 @@ RTS_API rtError_t rtMemset(void *devPtr, uint64_t destMax, uint32_t value, uint6 * @param [in] count byte num * @param [in] stream * @return RT_ERROR_NONE for ok, errno for failed + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtMemsetAsync(void *ptr, uint64_t destMax, uint32_t value, uint64_t count, rtStream_t stream); @@ -348,6 +346,7 @@ RTS_API rtError_t rtMemsetAsync(void *ptr, uint64_t destMax, uint32_t value, uin * @param [out] free * @param [out] total * @return RT_ERROR_NONE for ok, errno for failed + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtMemGetInfo(size_t *free, size_t *total); @@ -358,6 +357,7 @@ RTS_API rtError_t rtMemGetInfo(size_t *free, size_t *total); * @param [in] len * @param [in] device * @return RT_ERROR_NONE for ok, errno for failed + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtMemPrefetchToDevice(void *devPtr, uint64_t len, int32_t device); @@ -367,6 +367,7 @@ RTS_API rtError_t rtMemPrefetchToDevice(void *devPtr, uint64_t len, int32_t devi * @param [in] ptr * @param [out] attributes * @return RT_ERROR_NONE for ok, errno for failed + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtPointerGetAttributes(rtPointerAttributes_t *attributes, const void *ptr); @@ -377,7 +378,7 @@ RTS_API rtError_t rtPointerGetAttributes(rtPointerAttributes_t *attributes, cons * @param [in] 
name identification name * @param [in] byteCount identification byteCount * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_VALUE for error input of ptr, name, byteCount + * @return RT_ERROR_INVALID_VALUE for error input * @return RT_ERROR_DRV_ERR for driver error */ RTS_API rtError_t rtIpcSetMemoryName(const void *ptr, uint64_t byteCount, char *name, uint32_t len); @@ -387,7 +388,7 @@ RTS_API rtError_t rtIpcSetMemoryName(const void *ptr, uint64_t byteCount, char * * @brief destroy a interprocess shared memory * @param [in] name identification name * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_VALUE for error input of name + * @return RT_ERROR_INVALID_VALUE for error input * @return RT_ERROR_DRV_ERR for driver error */ rtError_t rtIpcDestroyMemoryName(const char *name); @@ -398,7 +399,7 @@ rtError_t rtIpcDestroyMemoryName(const char *name); * @param [in|out] ptr device memory address pointer * @param [in] name identification name * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_VALUE for error input of ptr, name + * @return RT_ERROR_INVALID_VALUE for error input * @return RT_ERROR_DRV_ERR for driver error */ RTS_API rtError_t rtIpcOpenMemory(void **ptr, const char *name); @@ -409,7 +410,7 @@ RTS_API rtError_t rtIpcOpenMemory(void **ptr, const char *name); * @param [in] ptr device memory address pointer * @param [in] name identification name * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_VALUE for error input of ptr, name + * @return RT_ERROR_INVALID_VALUE for error input * @return RT_ERROR_DRV_ERR for driver error */ RTS_API rtError_t rtIpcCloseMemory(const void *ptr); @@ -421,7 +422,7 @@ RTS_API rtError_t rtIpcCloseMemory(const void *ptr); * @param [in] wqe_index moudle index * @param [in] stream asynchronized task stream * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_VALUE for error input of ptr, name + * @return RT_ERROR_INVALID_VALUE for error input * @return RT_ERROR_DRV_ERR for driver error */ 
RTS_API rtError_t rtRDMASend(uint32_t index, uint32_t wqe_index, rtStream_t stream); @@ -434,7 +435,6 @@ RTS_API rtError_t rtRDMASend(uint32_t index, uint32_t wqe_index, rtStream_t stre * @param [in] num length of pid[] * @return RT_ERROR_NONE for ok * @return RT_ERROR_INVALID_VALUE for error input - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for invalid resource handle * @return RT_ERROR_DRV_ERR for driver error */ RTS_API rtError_t rtSetIpcMemPid(const char *name, int32_t pid[], int num); @@ -446,7 +446,7 @@ RTS_API rtError_t rtSetIpcMemPid(const char *name, int32_t pid[], int num); * @param [in] dbinfo doorbell info * @param [in] stream asynchronized task stream * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_VALUE for error input of ptr, name + * @return RT_ERROR_INVALID_VALUE for error input * @return RT_ERROR_DRV_ERR for driver error */ RTS_API rtError_t rtRDMADBSend(uint32_t dbIndex, uint64_t dbInfo, rtStream_t stream); diff --git a/third_party/fwkacllib/inc/runtime/rt_model.h b/third_party/fwkacllib/inc/runtime/rt_model.h index 5c85a3d7..089a90b7 100644 --- a/third_party/fwkacllib/inc/runtime/rt_model.h +++ b/third_party/fwkacllib/inc/runtime/rt_model.h @@ -277,7 +277,7 @@ typedef rtError_t (*rtTaskGenCallback)(rtModel_t model, rtTaskInfo_t *taskInfo); * @brief set callback for generate model * @param [in] callBack callback function * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input handle + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtSetTaskGenCallback(rtTaskGenCallback callback); @@ -287,7 +287,7 @@ RTS_API rtError_t rtSetTaskGenCallback(rtTaskGenCallback callback); * @param [out] model created model * @param [in] flag reserved * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input handle + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtModelCreate(rtModel_t *model, uint32_t flag); @@ -296,7 +296,7 @@ RTS_API 
rtError_t rtModelCreate(rtModel_t *model, uint32_t flag); * @brief destroy model instance * @param [in] model model to destroy * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input handle + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtModelDestroy(rtModel_t model); @@ -307,7 +307,7 @@ RTS_API rtError_t rtModelDestroy(rtModel_t model); * @param [in] stream binded stream * @param [in] flag reserved * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input handle + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtModelBindStream(rtModel_t model, rtStream_t stream, uint32_t flag); @@ -317,7 +317,7 @@ RTS_API rtError_t rtModelBindStream(rtModel_t model, rtStream_t stream, uint32_t * @param [in] model unbinded model * @param [in] stream unbinded stream * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input handle + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtModelUnbindStream(rtModel_t model, rtStream_t stream); @@ -334,7 +334,7 @@ RTS_API rtError_t rtModelLoadComplete(rtModel_t model); * @brief execute model instance * @param [in] model model to execute * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input handle + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtModelExecute(rtModel_t model, rtStream_t stream, uint32_t flag); @@ -345,7 +345,7 @@ RTS_API rtError_t rtModelExecute(rtModel_t model, rtStream_t stream, uint32_t fl * @param [out] taskid last task id of the model * @param [out] streamid last steam id of the model * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_VALUE for error input handle + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtModelGetTaskId(rtModel_t model, uint32_t *taskid, uint32_t *streamid); @@ -355,7 +355,7 @@ RTS_API rtError_t rtModelGetTaskId(rtModel_t model, 
uint32_t *taskid, uint32_t * * @param [in] model model to execute * @param [in] end graph stream * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_VALUE for error input handle + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtEndGraph(rtModel_t model, rtStream_t stream); @@ -366,7 +366,7 @@ RTS_API rtError_t rtEndGraph(rtModel_t model, rtStream_t stream); * @param [in] end graph stream * @param [in] flags AICPU datadump * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_VALUE for error input handle + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtEndGraphEx(rtModel_t model, rtStream_t stream, uint32_t flags); @@ -376,7 +376,7 @@ RTS_API rtError_t rtEndGraphEx(rtModel_t model, rtStream_t stream, uint32_t flag * @param [in] model model to execute * @param [in] flags EXECUTOR_TS | EXECUTOR_AICPU * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_VALUE for error input handle + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtModelExecutorSet(rtModel_t model, uint8_t flags); @@ -385,7 +385,7 @@ RTS_API rtError_t rtModelExecutorSet(rtModel_t model, uint8_t flags); * @brief abort model * @param [in] model model to abort * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_VALUE for error input handle + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtModelAbort(rtModel_t model); @@ -396,7 +396,7 @@ RTS_API rtError_t rtModelAbort(rtModel_t model); * @param [in] queueId queueId to bind * @param [in] flag * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_VALUE for error input handle + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtModelBindQueue(rtModel_t model, uint32_t queueId, rtModelQueueFlag_t flag); @@ -406,7 +406,7 @@ RTS_API rtError_t rtModelBindQueue(rtModel_t model, uint32_t queueId, rtModelQue * @param [in] model * @param [out] modelId model id * @return RT_ERROR_NONE for ok - * @return 
RT_ERROR_INVALID_VALUE for error input handle + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtModelGetId(rtModel_t model, uint32_t *modelId); @@ -417,7 +417,7 @@ RTS_API rtError_t rtModelGetId(rtModel_t model, uint32_t *modelId); * @param [in] model: model handle * @param [in] flag: debug flag * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_VALUE for error input handle + * @return RT_ERROR_INVALID_VALUE for error input */ rtError_t rtDebugRegister(rtModel_t model, uint32_t flag, const void *addr, uint32_t *streamId, uint32_t *taskId); @@ -426,7 +426,7 @@ rtError_t rtDebugRegister(rtModel_t model, uint32_t flag, const void *addr, uint * @brief disable debug for dump overflow exception * @param [in] model: model handle * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_VALUE for error input handle + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtDebugUnRegister(rtModel_t model); diff --git a/third_party/fwkacllib/inc/runtime/stream.h b/third_party/fwkacllib/inc/runtime/stream.h index 232b5169..3123c3a9 100644 --- a/third_party/fwkacllib/inc/runtime/stream.h +++ b/third_party/fwkacllib/inc/runtime/stream.h @@ -35,6 +35,7 @@ extern "C" { #define RT_STREAM_AICPU (0x08) #define RT_STREAM_FORBIDDEN_DEFAULT (0x10) #define RT_STREAM_HEAD (0x20) +#define RT_STREAM_PRIMARY_DEFAULT (0x40) /** * @ingroup stream_type @@ -54,8 +55,7 @@ extern "C" { * @param [in|out] stream created stream * @param [in] priority stream priority * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input stream handle - * @return RT_ERROR_INVALID_VALUE for error input priority + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtStreamCreate(rtStream_t *stream, int32_t priority); @@ -66,8 +66,7 @@ RTS_API rtError_t rtStreamCreate(rtStream_t *stream, int32_t priority); * @param [in] priority stream priority * @param [in] flags stream op flags * @return RT_ERROR_NONE for ok - * 
@return RT_ERROR_INVALID_RESOURCE_HANDLE for error input stream handle - * @return RT_ERROR_INVALID_VALUE for error input priority + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtStreamCreateWithFlags(rtStream_t *stream, int32_t priority, uint32_t flags); @@ -76,7 +75,7 @@ RTS_API rtError_t rtStreamCreateWithFlags(rtStream_t *stream, int32_t priority, * @brief destroy stream instance. * @param [in] stream the stream to destroy * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input stream handle + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtStreamDestroy(rtStream_t stream); @@ -86,7 +85,7 @@ RTS_API rtError_t rtStreamDestroy(rtStream_t stream); * @param [in] stream the wait stream * @param [in] event the event to wait * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input stream or event handle + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtStreamWaitEvent(rtStream_t stream, rtEvent_t event); @@ -95,7 +94,7 @@ RTS_API rtError_t rtStreamWaitEvent(rtStream_t stream, rtEvent_t event); * @brief wait stream to be complete * @param [in] stream stream to wait * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input stream or event handle + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtStreamSynchronize(rtStream_t stream); @@ -104,7 +103,7 @@ RTS_API rtError_t rtStreamSynchronize(rtStream_t stream); * @brief queries an asynchronous stream for completion status * @param [in] stream stream to query * @return RT_ERROR_NONE for complete - * @return RT_ERROR_NOT_READY for not complete + * @return RT_ERROR_STREAM_NOT_COMPLETE for not complete */ RTS_API rtError_t rtStreamQuery(rtStream_t stream); @@ -114,7 +113,7 @@ RTS_API rtError_t rtStreamQuery(rtStream_t stream); * @param [in] stream stream hadle * @param [in] streamId stream id * @return RT_ERROR_NONE for 
complete - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input stream handle + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtGetStreamId(rtStream_t stream, int32_t *streamId); @@ -125,7 +124,7 @@ RTS_API rtError_t rtGetStreamId(rtStream_t stream, int32_t *streamId); * @param [in] MaxStrCount Max stream count * @param [in] MaxTaskCount max task count per stream * @return RT_ERROR_NONE for complete - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input stream handle + * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtGetMaxStreamAndTask(uint32_t streamType, uint32_t *MaxStrCount, uint32_t *MaxTaskCount); @@ -136,7 +135,6 @@ RTS_API rtError_t rtGetMaxStreamAndTask(uint32_t streamType, uint32_t *MaxStrCou * @param [in] name identification name * @return RT_ERROR_NONE for complete * @return RT_ERROR_INVALID_VALUE for error input - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for invalid resource handle */ RTS_API rtError_t rtNameStream(rtStream_t stream_, const char *name); @@ -150,9 +148,6 @@ RTS_API rtError_t rtNameStream(rtStream_t stream_, const char *name); * @param [in] stream input stream to init task * @return RT_ERROR_NONE for complete * @return RT_ERROR_INVALID_VALUE for error input - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for invalid resource handle - * @return RT_ERROR_INVALID_DEVICE for invalid device handle - * @return ERROR_RECYCLE for switching task init failed or submit failed */ RTS_API rtError_t rtStreamSwitch(void *ptr, rtCondition_t condition, int64_t value, rtStream_t true_stream, rtStream_t stream); @@ -166,7 +161,6 @@ RTS_API rtError_t rtStreamSwitch(void *ptr, rtCondition_t condition, int64_t val * @param [in] stream stream id * @param [in] dataType data type of target value * @return RT_ERROR_NONE for complete - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for not complete */ RTS_API rtError_t rtStreamSwitchEx(void *ptr, rtCondition_t condition, void *value_ptr, rtStream_t 
true_stream, rtStream_t stream, rtSwitchDataType_t dataType); @@ -178,9 +172,6 @@ RTS_API rtError_t rtStreamSwitchEx(void *ptr, rtCondition_t condition, void *val * @param [in] stream input stream to init task * @return RT_ERROR_NONE for complete * @return RT_ERROR_INVALID_VALUE for error input - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for invalid resource handle - * @return RT_ERROR_INVALID_DEVICE for invalid device handle - * @return ERROR_RECYCLE for switching task init failed or submit failed */ RTS_API rtError_t rtStreamActive(rtStream_t active_stream, rtStream_t stream); @@ -194,7 +185,6 @@ RTS_API rtError_t rtStreamActive(rtStream_t active_stream, rtStream_t stream); * @param [in] stream input stream to init task * @param [in] dataType data type of target value * @return RT_ERROR_NONE for complete - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for not complete */ RTS_API rtError_t rtStreamSwitchN(void *ptr, uint32_t size, void *valuePtr, rtStream_t *trueStreamPtr, uint32_t elementSize, rtStream_t stream, rtSwitchDataType_t dataType); diff --git a/third_party/fwkacllib/inc/tdt/tsd_client.h b/third_party/fwkacllib/inc/tdt/tsd_client.h index 6aaca646..7886488e 100644 --- a/third_party/fwkacllib/inc/tdt/tsd_client.h +++ b/third_party/fwkacllib/inc/tdt/tsd_client.h @@ -151,84 +151,6 @@ TDT_StatusT GetCmdParameterObjAttribute(tdt::TsdCmdType type, void *cmdParameter */ TDT_StatusT TsdClientCmd(tdt::TsdCmdType cmd, void *cmdParameterObj); -namespace tdt { -/** -* @ingroup RANK_SIZE_DEFAULT_VALUE。 -* The default value of Rank size is 1 by default. -* It does not pull up HCCP to perform set communication related operations. -* -*/ -constexpr uint32_t RANK_SIZE_DEFAULT_VALUE = 1; - -class TsdClient { - public: - /** - * @ingroup GetInstance - * @brief Get TsdClient instance - * - * @par Function - * Get TsdClient instance - * - * @param NA - * @retval TsdClient TsdClient instance - * - * @par Dependency - * @li libtsdclient.so: Library to which the interface belongs. 
- * @li tsd_client.h: Header file where the interface declaration is located. - */ - static TsdClient *GetInstance(); - - /** - * @ingroup ~TsdClient - * @brief TsdClient destructor - * - * @par Function - * TsdClient destructor - * - * @param NA - * @retval NA - * - * @par Dependency - * @li libtsdclient.so: Library to which the interface belongs. - * @li tsd_client.h: Header file where the interface declaration is located. - */ - ~TsdClient(); - - /** - * @ingroup Open - * @brief Used for the Framework process to communicate with the TSDDaemon process, - * and notify TSD to complete the initialization of other processes - * - * @par Function - * Used for the Framework process to communicate with the TSDDaemon process, - * and notify TSD to complete the initialization of other processes - * - * @param phyDeviceId [IN] type #unsigned int. Physical device ID - * @param rankSize [IN] type #unsigned int. The rankSize of the training. - * The default value is 1. When rankSize is greater than 1, - * HCCP will be pulled to perform set communication related operations. - * @retval TDT_OK Success - * @retval OtherValues Failure - * - * @par Dependency - * @li libtsdclient.so: Library to which the interface belongs. - * @li tsd_client.h: Header file where the interface declaration is located. 
- * @li data_common.h: Header file where 'TDT_StatusT' defined - */ - TDT_StatusT Open(const uint32_t phyDeviceId, const uint32_t rankSize = RANK_SIZE_DEFAULT_VALUE); - - TDT_StatusT Close(); - - private: - TsdClient(); - TsdClient(const TsdClient &) = delete; - TsdClient(TsdClient &&) = delete; - TsdClient &operator=(const TsdClient &) = delete; - TsdClient &operator=(TsdClient &&) = delete; - - uint32_t rankSize_; -}; -} // namespace tdt #ifdef __cplusplus } #endif // __cplusplus diff --git a/third_party/fwkacllib/inc/toolchain/slog.h b/third_party/fwkacllib/inc/toolchain/slog.h index 261fe866..2cb00a05 100644 --- a/third_party/fwkacllib/inc/toolchain/slog.h +++ b/third_party/fwkacllib/inc/toolchain/slog.h @@ -175,6 +175,8 @@ enum { AIVECTOR, TBE, FV, + MDCMAP, + TUNE, INVLID_MOUDLE_ID }; @@ -209,6 +211,7 @@ extern int dlog_setlevel(int moduleId, int level, int enableEvent); /** * @ingroup slog * @brief CheckLogLevel: check module level enable or not + * users no need to call it because all dlog interface(include inner interface) has already called * * @param [in]moduleId: module id, eg: CCE * @param [in]logLevel: eg: DLOG_EVENT/DLOG_ERROR/DLOG_WARN/DLOG_INFO/DLOG_DEBUG @@ -231,37 +234,46 @@ extern int CheckLogLevel(int moduleId, int logLevel); /** * @ingroup slog * @brief dlog_warn: print warning log + * call CheckLogLevel in advance to optimize performance, call interface with fmt input take time * * @param [in]moduleId: module id, eg: CCE * @param [in]fmt: log content */ -#define dlog_warn(moduleId, fmt, ...) \ - do { \ - DlogWarnInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ +#define dlog_warn(moduleId, fmt, ...) 
\ + do { \ + if(CheckLogLevel(moduleId, DLOG_WARN) == 1) { \ + DlogWarnInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ + } \ } while (0) /** * @ingroup slog * @brief dlog_info: print info log + * call CheckLogLevel in advance to optimize performance, call interface with fmt input take time * * @param [in]moduleId: module id, eg: CCE * @param [in]fmt: log content */ -#define dlog_info(moduleId, fmt, ...) \ - do { \ - DlogInfoInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ +#define dlog_info(moduleId, fmt, ...) \ + do { \ + if(CheckLogLevel(moduleId, DLOG_INFO) == 1) { \ + DlogInfoInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ + } \ } while (0) /** * @ingroup slog * @brief dlog_debug: print debug log + * call CheckLogLevel in advance to optimize performance, call interface with fmt input take time * * @param [in]moduleId: module id, eg: CCE * @param [in]fmt: log content */ -#define dlog_debug(moduleId, fmt, ...) \ - do { \ - DlogDebugInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ +#define dlog_debug(moduleId, fmt, ...) \ + do { \ + if(CheckLogLevel(moduleId, DLOG_DEBUG) == 1) { \ + DlogDebugInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ + } \ } while (0) /** @@ -279,33 +291,40 @@ extern int CheckLogLevel(int moduleId, int logLevel); /** * @ingroup slog * @brief Dlog: print log, need caller to specify level + * call CheckLogLevel in advance to optimize performance, call interface with fmt input take time * * @param [in]moduleId: module id, eg: CCE * @param [in]level(0: debug, 1: info, 2: warning, 3: error, 5: trace, 6: oplog, 16: event) * @param [in]fmt: log content */ -#define Dlog(moduleId, level, fmt, ...) \ - do { \ - DlogInner(moduleId, level, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ +#define Dlog(moduleId, level, fmt, ...) 
\ + do { \ + if(CheckLogLevel(moduleId, level) == 1) { \ + DlogInner(moduleId, level, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ + } \ } while (0) /** * @ingroup slog * @brief DlogSub: print log, need caller to specify level and submodule + * call CheckLogLevel in advance to optimize performance, call interface with fmt input take time * * @param [in]moduleId: module id, eg: CCE * @param [in]submodule: eg: engine * @param [in]level(0: debug, 1: info, 2: warning, 3: error, 5: trace, 6: oplog, 16: event) * @param [in]fmt: log content */ -#define DlogSub(moduleId, submodule, level, fmt, ...) \ - do { \ - DlogInner(moduleId, level, "[%s:%d][%s]" fmt, __FILE__, __LINE__, submodule, ##__VA_ARGS__); \ +#define DlogSub(moduleId, submodule, level, fmt, ...) \ + do { \ + if(CheckLogLevel(moduleId, level) == 1) { \ + DlogInner(moduleId, level, "[%s:%d][%s]" fmt, __FILE__, __LINE__, submodule, ##__VA_ARGS__); \ + } \ } while (0) /** * @ingroup slog * @brief DlogWithKV: print log, need caller to specify level and other paramters + * call CheckLogLevel in advance to optimize performance, call interface with fmt input take time * * @param [in]moduleId: module id, eg: CCE * @param [in]level(0: debug, 1: info, 2: warning, 3: error, 5: trace, 6: oplog, 16: event) @@ -313,9 +332,11 @@ extern int CheckLogLevel(int moduleId, int logLevel); * @param [in]kvNum: key-value element num in array * @param [in]fmt: log content */ -#define DlogWithKV(moduleId, level, pstKVArray, kvNum, fmt, ...) \ - do { \ - DlogWithKVInner(moduleId, level, pstKVArray, kvNum, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ +#define DlogWithKV(moduleId, level, pstKVArray, kvNum, fmt, ...) \ + do { \ + if(CheckLogLevel(moduleId, level) == 1) { \ + DlogWithKVInner(moduleId, level, pstKVArray, kvNum, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ + } \ } while (0) /**