Browse Source

Merge branch 'master' of gitee.com:mindspore/graphengine

pull/1907/head^2
zhaoxinxin 4 years ago
parent
commit
6a810c9024
92 changed files with 1583 additions and 726 deletions
  1. +1
    -2
      .clang-format
  2. +2
    -0
      ge/common/CMakeLists.txt
  3. +1
    -1
      ge/common/dump/dump_properties.cc
  4. +1
    -0
      ge/common/dump/exception_dumper.cc
  5. +4
    -0
      ge/executor/CMakeLists.txt
  6. +1
    -1
      ge/ge_runtime/task/label_goto_task.cc
  7. +1
    -2
      ge/graph/build/model_builder.cc
  8. +20
    -24
      ge/graph/build/task_generator.cc
  9. +1
    -1
      ge/graph/build/task_generator.h
  10. +2
    -0
      ge/graph/load/model_manager/task_info/kernel_task_info.cc
  11. +0
    -49
      ge/graph/manager/graph_var_manager.cc
  12. +0
    -15
      ge/graph/manager/graph_var_manager.h
  13. +0
    -66
      ge/graph/manager/trans_var_data_utils.cc
  14. +0
    -11
      ge/graph/manager/trans_var_data_utils.h
  15. +25
    -5
      ge/graph/passes/infer_value_range_pass.cc
  16. +1
    -0
      ge/graph/preprocess/insert_op/util_insert_aipp_op.cc
  17. +1
    -1
      ge/graph/preprocess/multi_batch_copy_graph.cc
  18. +5
    -3
      ge/hybrid/executor/hybrid_model_async_executor.cc
  19. +0
    -4
      ge/hybrid/executor/hybrid_model_executor.cc
  20. +0
    -1
      ge/hybrid/executor/hybrid_model_pipeline_executor.cc
  21. +9
    -2
      ge/hybrid/executor/worker/task_compile_engine.cc
  22. +1
    -0
      ge/hybrid/model/hybrid_model_builder.cc
  23. +23
    -2
      ge/hybrid/model/node_item.cc
  24. +7
    -2
      ge/ir_build/option_utils.cc
  25. +4
    -3
      ge/offline/main.cc
  26. +68
    -85
      ge/single_op/single_op_model.cc
  27. +5
    -2
      ge/single_op/single_op_model.h
  28. +5
    -1
      ge/single_op/task/op_task.h
  29. +24
    -1
      ge/single_op/task/tbe_task_builder.cc
  30. +1
    -0
      ge/single_op/task/tbe_task_builder.h
  31. +1
    -1
      metadef
  32. +1
    -1
      parser
  33. +15
    -0
      scripts/env/Dockerfile
  34. +2
    -2
      scripts/env/ge_env.sh
  35. +1
    -0
      tests/depends/cce/CMakeLists.txt
  36. +0
    -13
      tests/framework/CMakeLists.txt
  37. +19
    -3
      tests/framework/easy_graph/include/easy_graph/builder/graph_dsl.h
  38. +6
    -2
      tests/framework/easy_graph/src/layout/graph_layout.cc
  39. +37
    -0
      tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/assert_error.h
  40. +32
    -0
      tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/check_utils.h
  41. +32
    -17
      tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/filter_scope_guard.h
  42. +59
    -0
      tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/graph_assert.h
  43. +2
    -4
      tests/framework/ge_graph_dsl/include/ge_graph_dsl/op_desc/op_desc_cfg.h
  44. +26
    -0
      tests/framework/ge_graph_dsl/src/assert/assert_error.cc
  45. +34
    -0
      tests/framework/ge_graph_dsl/src/assert/check_utils.cc
  46. +31
    -0
      tests/framework/ge_graph_dsl/src/assert/filter_scope_guard.cc
  47. +33
    -0
      tests/framework/ge_graph_dsl/src/assert/ge_dump_filter.h
  48. +79
    -0
      tests/framework/ge_graph_dsl/src/assert/ge_graph_check_dumper.cc
  49. +49
    -0
      tests/framework/ge_graph_dsl/src/assert/ge_graph_check_dumper.h
  50. +32
    -0
      tests/framework/ge_graph_dsl/src/assert/ge_graph_checker.h
  51. +28
    -0
      tests/framework/ge_graph_dsl/src/assert/ge_graph_default_checker.cc
  52. +41
    -0
      tests/framework/ge_graph_dsl/src/assert/ge_graph_default_checker.h
  53. +0
    -0
      tests/framework/ge_graph_dsl/src/op_desc/op_desc_cfg_box.cc
  54. +12
    -5
      tests/framework/ge_graph_dsl/src/op_desc/op_desc_cfg_repo.cc
  55. +1
    -3
      tests/framework/ge_graph_dsl/src/op_desc/op_desc_ptr_box.cc
  56. +3
    -9
      tests/framework/ge_graph_dsl/src/vistor/ge_graph_visitor.cc
  57. +0
    -0
      tests/framework/ge_graph_dsl/src/vistor/ge_subgraph_vistor.cc
  58. +0
    -0
      tests/framework/ge_graph_dsl/src/vistor/graph_dsl.cc
  59. +1
    -1
      tests/framework/ge_graph_dsl/tests/CMakeLists.txt
  60. +129
    -0
      tests/framework/ge_graph_dsl/tests/check_graph_test.cc
  61. +16
    -28
      tests/framework/ge_graph_dsl/tests/graph_dsl_test.cc
  62. +6
    -0
      tests/framework/ge_graph_dsl/tests/stub/optype_stub.cc
  63. +25
    -22
      tests/framework/ge_graph_dsl/tests/test_main.cc
  64. +0
    -48
      tests/framework/utils/builder/graph_builder_utils.cc
  65. +0
    -55
      tests/framework/utils/builder/graph_builder_utils.h
  66. +1
    -1
      tests/st/testcase/CMakeLists.txt
  67. +49
    -78
      tests/st/testcase/test_framework_dummy.cc
  68. +4
    -16
      tests/st/testcase/test_ge_opt_info.cc
  69. +2
    -2
      tests/st/testcase/test_main.cc
  70. +1
    -0
      tests/ut/common/graph/CMakeLists.txt
  71. +1
    -0
      tests/ut/ge/CMakeLists.txt
  72. +3
    -1
      tests/ut/ge/graph/build/task_generator_unittest.cc
  73. +1
    -1
      tests/ut/ge/graph/passes/addn_pass_unittest.cc
  74. +115
    -0
      tests/ut/ge/graph/passes/base_pass_unittest.cc
  75. +45
    -0
      tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc
  76. +28
    -0
      tests/ut/ge/hybrid/executor/hybrid_model_async_executor_unittest.cc
  77. +12
    -0
      tests/ut/ge/hybrid/executor/worker/execution_engine_unittest.cc
  78. +1
    -0
      tests/ut/ge/hybrid/ge_hybrid_unittest.cc
  79. +1
    -0
      tests/ut/ge/hybrid/node_executor/node_executor_unittest.cc
  80. +20
    -1
      tests/ut/ge/single_op/single_op_model_unittest.cc
  81. +1
    -0
      tests/ut/ge/single_op/single_op_task_unittest.cc
  82. +7
    -0
      third_party/fwkacllib/inc/external/runtime/rt_error_codes.h
  83. +4
    -4
      third_party/fwkacllib/inc/runtime/base.h
  84. +43
    -0
      third_party/fwkacllib/inc/runtime/config.h
  85. +5
    -0
      third_party/fwkacllib/inc/runtime/dev.h
  86. +35
    -2
      third_party/fwkacllib/inc/runtime/event.h
  87. +66
    -17
      third_party/fwkacllib/inc/runtime/kernel.h
  88. +11
    -0
      third_party/fwkacllib/inc/runtime/mem.h
  89. +11
    -2
      third_party/fwkacllib/inc/runtime/rt_model.h
  90. +30
    -1
      third_party/fwkacllib/inc/toolchain/prof_callback.h
  91. +32
    -30
      third_party/fwkacllib/inc/toolchain/slog.h
  92. +88
    -72
      third_party/fwkacllib/inc/toolchain/tuning_tool/tune_api.h

+ 1
- 2
.clang-format View File

@@ -52,7 +52,6 @@ ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4 ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4 ContinuationIndentWidth: 4
Cpp11BracedListStyle: true Cpp11BracedListStyle: true
DerivePointerAlignment: true
DisableFormat: false DisableFormat: false
ExperimentalAutoDetectBinPacking: false ExperimentalAutoDetectBinPacking: false
FixNamespaceComments: true FixNamespaceComments: true
@@ -94,7 +93,7 @@ PenaltyBreakString: 1000
PenaltyBreakTemplateDeclaration: 10 PenaltyBreakTemplateDeclaration: 10
PenaltyExcessCharacter: 1000000 PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 200 PenaltyReturnTypeOnItsOwnLine: 200
PointerAlignment: Left
PointerAlignment: Right
RawStringFormats: RawStringFormats:
- Language: Cpp - Language: Cpp
Delimiters: Delimiters:


+ 2
- 0
ge/common/CMakeLists.txt View File

@@ -95,6 +95,7 @@ target_link_libraries(ge_common PRIVATE
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:slog_headers>> $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:slog_headers>>
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:msprof_headers>> $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:msprof_headers>>
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:mmpa_headers>> $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:mmpa_headers>>
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:runtime_headers>>
static_mmpa static_mmpa
-Wl,--no-as-needed -Wl,--no-as-needed
graph graph
@@ -155,6 +156,7 @@ target_link_libraries(ge_common_static PRIVATE
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:slog_headers>> $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:slog_headers>>
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:msprof_headers>> $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:msprof_headers>>
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:mmpa_headers>> $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:mmpa_headers>>
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:runtime_headers>>
ascend_protobuf_static ascend_protobuf_static
json json
c_sec c_sec


+ 1
- 1
ge/common/dump/dump_properties.cc View File

@@ -163,7 +163,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status DumpProperties::CheckDum
GELOGE(PARAM_INVALID, "[Check][Param] the dumpPath:%s, is invalid.", input.c_str()); GELOGE(PARAM_INVALID, "[Check][Param] the dumpPath:%s, is invalid.", input.c_str());
return PARAM_INVALID; return PARAM_INVALID;
} }
if (mmAccess2(trusted_path, R_OK | W_OK) != EN_OK) {
if (mmAccess2(trusted_path, M_R_OK | M_W_OK) != EN_OK) {
REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}), REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}),
std::vector<std::string>({ std::vector<std::string>({
"ge.exec.dumpPath", "ge.exec.dumpPath",


+ 1
- 0
ge/common/dump/exception_dumper.cc View File

@@ -161,6 +161,7 @@ Status ExceptionDumper::DumpExceptionInfo(const std::vector<rtExceptionInfo> &ex


uint64_t proto_size = dump_data.ByteSizeLong(); uint64_t proto_size = dump_data.ByteSizeLong();
std::unique_ptr<char[]> proto_msg(new (std::nothrow) char[proto_size]); std::unique_ptr<char[]> proto_msg(new (std::nothrow) char[proto_size]);
GE_CHECK_NOTNULL(proto_msg);
bool ret = dump_data.SerializeToArray(proto_msg.get(), proto_size); bool ret = dump_data.SerializeToArray(proto_msg.get(), proto_size);
if (!ret || proto_size == 0) { if (!ret || proto_size == 0) {
REPORT_INNER_ERROR("E19999", "Serialize proto to string fail"); REPORT_INNER_ERROR("E19999", "Serialize proto to string fail");


+ 4
- 0
ge/executor/CMakeLists.txt View File

@@ -186,6 +186,8 @@ target_include_directories(ge_executor SYSTEM PRIVATE
${CMAKE_BINARY_DIR}/proto/graphengine_protos ${CMAKE_BINARY_DIR}/proto/graphengine_protos
#### yellow zone #### #### yellow zone ####
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:${GE_DEPEND_DIR}/inc> $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:${GE_DEPEND_DIR}/inc>
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<TARGET_PROPERTY:runtime_headers,INTERFACE_INCLUDE_DIRECTORIES>>
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<TARGET_PROPERTY:cce_headers,INTERFACE_INCLUDE_DIRECTORIES>>
#### blue zone #### #### blue zone ####
$<$<BOOL:${ENABLE_OPEN_SRC}>:${GE_CODE_DIR}/third_party/fwkacllib/inc> $<$<BOOL:${ENABLE_OPEN_SRC}>:${GE_CODE_DIR}/third_party/fwkacllib/inc>
$<$<BOOL:${ENABLE_OPEN_SRC}>:${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain> $<$<BOOL:${ENABLE_OPEN_SRC}>:${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain>
@@ -251,6 +253,8 @@ target_link_libraries(ge_executor_shared PRIVATE
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:slog_headers>> $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:slog_headers>>
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:msprof_headers>> $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:msprof_headers>>
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:mmpa_headers>> $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:mmpa_headers>>
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:runtime_headers>>
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:cce_headers>>
-Wl,--no-as-needed -Wl,--no-as-needed
ge_common ge_common
runtime runtime


+ 1
- 1
ge/ge_runtime/task/label_goto_task.cc View File

@@ -72,7 +72,7 @@ bool LabelGotoTask::Distribute() {
return false; return false;
} }


rt_ret = rtLabelListCpy((void**)label_list.data(), label_list.size(), label_info_, label_info_size);
rt_ret = rtLabelListCpy(reinterpret_cast<void**>(label_list.data()), label_list.size(), label_info_, label_info_size);
if (rt_ret != RT_ERROR_NONE) { if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret);
return false; return false;


+ 1
- 2
ge/graph/build/model_builder.cc View File

@@ -32,7 +32,6 @@
#include "graph/ge_attr_value.h" #include "graph/ge_attr_value.h"
#include "graph/ge_context.h" #include "graph/ge_context.h"
#include "external/graph/ge_error_codes.h" #include "external/graph/ge_error_codes.h"
#include "graph/manager/graph_mem_allocator.h"
#include "graph/manager/graph_var_manager.h" #include "graph/manager/graph_var_manager.h"
#include "graph/optimize/common/params.h" #include "graph/optimize/common/params.h"
#include "external/graph/types.h" #include "external/graph/types.h"
@@ -707,7 +706,7 @@ Status ModelBuilder::SaveDataToModel(ge::Model &model, ge::GeModel &ge_model) {
if (!kernel_name.empty() && (kernel_buffer.GetSize() > 0)) { if (!kernel_name.empty() && (kernel_buffer.GetSize() > 0)) {
GE_CHECK_NOTNULL(kernel_buffer.GetData()); GE_CHECK_NOTNULL(kernel_buffer.GetData());
std::vector<char> data(kernel_buffer.GetData(), kernel_buffer.GetData() + kernel_buffer.GetSize()); std::vector<char> data(kernel_buffer.GetData(), kernel_buffer.GetData() + kernel_buffer.GetSize());
tbe_kernel = std::make_shared<OpKernelBin>(kernel_name, std::move(data));
tbe_kernel = MakeShared<OpKernelBin>(kernel_name, std::move(data));
GE_CHECK_NOTNULL(tbe_kernel); GE_CHECK_NOTNULL(tbe_kernel);
GELOGI("Node [%s][%s] start recovery extra attr %s from %s", node_op_desc->GetName().c_str(), GELOGI("Node [%s][%s] start recovery extra attr %s from %s", node_op_desc->GetName().c_str(),
node_op_desc->GetType().c_str(), ge::OP_EXTATTR_NAME_TBE_KERNEL, ATTR_NAME_TBE_KERNEL_NAME.c_str()); node_op_desc->GetType().c_str(), ge::OP_EXTATTR_NAME_TBE_KERNEL, ATTR_NAME_TBE_KERNEL_NAME.c_str());


+ 20
- 24
ge/graph/build/task_generator.cc View File

@@ -793,7 +793,6 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP
GELOGI("Start AutoFindBpOpIndex"); GELOGI("Start AutoFindBpOpIndex");
NodePtr bp_node = nullptr; NodePtr bp_node = nullptr;
uint32_t current_idx = 0; uint32_t current_idx = 0;
uint32_t netoutput_idx = 0;
for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) {
OpDescPtr op_desc = node->GetOpDesc(); OpDescPtr op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc); GE_CHECK_NOTNULL(op_desc);
@@ -811,7 +810,6 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP
if (op_desc->GetName() == NODE_NAME_NET_OUTPUT) { if (op_desc->GetName() == NODE_NAME_NET_OUTPUT) {
if (bp_node == nullptr) { if (bp_node == nullptr) {
bp_node = node; bp_node = node;
netoutput_idx = current_idx - 1;
} }
} }
if (graph->GetNeedIteration()) { if (graph->GetNeedIteration()) {
@@ -836,34 +834,30 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP
if (bp_node == nullptr) { if (bp_node == nullptr) {
GELOGW("not find bp_node."); GELOGW("not find bp_node.");
return SUCCESS; return SUCCESS;
} else if (bp_node->GetName() == NODE_NAME_NET_OUTPUT) {
profiling_point.bp_index = netoutput_idx;
GELOGI("First bp name %s, idx %u", bp_node->GetName().c_str(), netoutput_idx);
} else {
profiling_point.bp_index = FindLastBpFromBpNode(graph, bp_node);
} }


return SUCCESS;
return FindLastBpFromBpNode(graph, bp_node, profiling_point.bp_index);
} }


uint32_t TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node) const {
uint32_t last_bp = 0;
Status TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &target_node,
uint32_t &bp_index) const {
bp_index = 0;
auto target_desc = target_node->GetOpDesc();
GE_CHECK_NOTNULL(target_desc);
OpDescPtr bp_op_desc = nullptr; OpDescPtr bp_op_desc = nullptr;
for (auto &in_anchor : bp_node->GetAllInDataAnchors()) {
auto out_anchor = in_anchor->GetPeerOutAnchor();
if (out_anchor == nullptr || out_anchor->GetOwnerNode() == nullptr) {
continue;
}
auto out_node_desc = out_anchor->GetOwnerNode()->GetOpDesc();
GE_CHECK_NOTNULL(out_node_desc);
if (bp_op_desc == nullptr || ((out_node_desc->GetId()) > (bp_op_desc->GetId()))) {
bp_op_desc = out_node_desc;
for (auto &in_node : target_node->GetInAllNodes()) {
GE_CHECK_NOTNULL(in_node);
auto in_node_desc = in_node->GetOpDesc();
GE_CHECK_NOTNULL(in_node_desc);
if ((bp_op_desc == nullptr || (in_node_desc->GetId() > bp_op_desc->GetId())) &&
(in_node_desc->GetStreamId() == target_desc->GetStreamId())){
bp_op_desc = in_node_desc;
} }
GELOGI("bp_op_desc is %s, id is %ld", bp_op_desc->GetName().c_str(), bp_op_desc->GetId());
} }


if (bp_op_desc == nullptr) { if (bp_op_desc == nullptr) {
return last_bp;
GELOGI("Did not find bp node.");
return SUCCESS;
} }
uint32_t current_idx = 0; uint32_t current_idx = 0;
for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) {
@@ -871,12 +865,14 @@ uint32_t TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const
GE_CHECK_NOTNULL(op_desc); GE_CHECK_NOTNULL(op_desc);
current_idx++; current_idx++;
if (op_desc->GetName() == bp_op_desc->GetName()) { if (op_desc->GetName() == bp_op_desc->GetName()) {
last_bp = current_idx;
GELOGI("First bp name %s, idx %u", op_desc->GetName().c_str(), last_bp);
bp_index = current_idx;
GELOGI("Find bp name %s, idx %u", op_desc->GetName().c_str(), bp_index);
break; break;
} }
} }
return last_bp;
GELOGI("Last bp node[%s], type[%s], index[%u], stream id[%ld]", bp_op_desc->GetName().c_str(),
bp_op_desc->GetType().c_str(), bp_index, bp_op_desc->GetStreamId());
return SUCCESS;
} }


Status TaskGenerator::FindFpOfEnv(const ComputeGraphPtr &graph, const std::string &fp_point_str, Status TaskGenerator::FindFpOfEnv(const ComputeGraphPtr &graph, const std::string &fp_point_str,


+ 1
- 1
ge/graph/build/task_generator.h View File

@@ -116,7 +116,7 @@ class TaskGenerator {
Status AutoFindFpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point) const; Status AutoFindFpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point) const;
Status AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point, Status AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point,
vector<uint32_t> &all_reduce_nodes) const; vector<uint32_t> &all_reduce_nodes) const;
uint32_t FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node) const;
Status FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node, uint32_t &bp_index) const;


Status FindFpOfEnv(const ComputeGraphPtr &graph, const std::string &fp_point_str, Status FindFpOfEnv(const ComputeGraphPtr &graph, const std::string &fp_point_str,
ProfilingPoint &profiling_point) const; ProfilingPoint &profiling_point) const;


+ 2
- 0
ge/graph/load/model_manager/task_info/kernel_task_info.cc View File

@@ -645,6 +645,7 @@ Status KernelTaskInfo::InitTVMTask(uint16_t offset, const domi::KernelDef &kerne
GE_CHECK_NOTNULL(op_desc); GE_CHECK_NOTNULL(op_desc);


args_addr = std::unique_ptr<uint8_t[]>(new (std::nothrow) uint8_t[args_size_]); args_addr = std::unique_ptr<uint8_t[]>(new (std::nothrow) uint8_t[args_size_]);
GE_CHECK_NOTNULL(args_addr);
errno_t sec_ret = memcpy_s(args_addr.get(), args_size_, kernel_def.args().data(), args_size_); errno_t sec_ret = memcpy_s(args_addr.get(), args_size_, kernel_def.args().data(), args_size_);
if (sec_ret != EOK) { if (sec_ret != EOK) {
REPORT_CALL_ERROR("E19999", "Call memcpy_s fail, size:%u, ret:0x%X", args_size_, sec_ret); REPORT_CALL_ERROR("E19999", "Call memcpy_s fail, size:%u, ret:0x%X", args_size_, sec_ret);
@@ -1000,6 +1001,7 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k


// copy args to new host memory // copy args to new host memory
args_addr = std::unique_ptr<uint8_t[]>(new (std::nothrow) uint8_t[args_size_]); args_addr = std::unique_ptr<uint8_t[]>(new (std::nothrow) uint8_t[args_size_]);
GE_CHECK_NOTNULL(args_addr);
GE_PRINT_DYNAMIC_MEMORY(new, "cce task physical memory.", sizeof(uint8_t) * args_size_) GE_PRINT_DYNAMIC_MEMORY(new, "cce task physical memory.", sizeof(uint8_t) * args_size_)
errno_t sec_ret = memcpy_s(args_addr.get(), args_size_, kernel_def.args().data(), args_size_); errno_t sec_ret = memcpy_s(args_addr.get(), args_size_, kernel_def.args().data(), args_size_);
if (sec_ret != EOK) { if (sec_ret != EOK) {


+ 0
- 49
ge/graph/manager/graph_var_manager.cc View File

@@ -194,35 +194,6 @@ ge::Status VarResource::GetBroadCastInfo(uint32_t graph_id, const string &var_na
return SUCCESS; return SUCCESS;
} }


ge::Status VarResource::SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name,
const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) {
GE_CHECK_NOTNULL(base_ptr);
GELOGI("SyncVarData2BroadCast graph_id: %u, var_name: %s.", graph_id, var_name.c_str());

VarBroadCastInfo var_broadcast_info = var_broad_cast_info_[graph_id][var_name];
uint8_t *dst_addr = base_ptr + var_broadcast_info.input_offset;

return ge::TransVarDataUtils::SyncVarData2BroadCast(var_name, var_tensor_desc, dst_addr,
var_broadcast_info.input_size, session_id_);
}

ge::Status VarResource::SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name,
const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) {
GELOGI("SyncBroadCastData2Var var_name: %s", var_name.c_str());

VarBroadCastInfo var_broadcast_info = var_broad_cast_info_[graph_id][var_name];
// subgraph base_ptr could be nullptr, task it as base 0
uint8_t *dst_addr = base_ptr + var_broadcast_info.output_offset;

return ge::TransVarDataUtils::SyncBroadCastData2Var(dst_addr, var_broadcast_info.output_size, var_name,
var_tensor_desc, session_id_);
}

ge::Status VarResource::SyncVarData(uint32_t graph_id, const std::string &var_name,
const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) {
return SyncVarData2BroadCast(graph_id, var_name, var_tensor_desc, base_ptr);
}

bool VarResource::IsVarAddr(const int64_t &offset) { return var_offset_map_.count(offset) > 0; } bool VarResource::IsVarAddr(const int64_t &offset) { return var_offset_map_.count(offset) > 0; }


rtMemType_t VarResource::GetVarMemType(const int64_t &offset) { rtMemType_t VarResource::GetVarMemType(const int64_t &offset) {
@@ -638,16 +609,6 @@ bool VarManager::IsVarExist(const std::string &var_name) {
return var_resource_->IsVarExist(var_name); return var_resource_->IsVarExist(var_name);
} }


ge::Status VarManager::SyncVarData(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc,
uint8_t *base_ptr) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
if (var_resource_ == nullptr) {
GELOGW("VarManager has not been init.");
return ge::INTERNAL_ERROR;
}
return var_resource_->SyncVarData(graph_id, var_name, var_tensor_desc, base_ptr);
}

ge::Status VarManager::GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc) { ge::Status VarManager::GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc) {
std::lock_guard<std::recursive_mutex> lock(mutex_); std::lock_guard<std::recursive_mutex> lock(mutex_);
GELOGI("VarManager::GetCurVarDesc var_name = %s.", var_name.c_str()); GELOGI("VarManager::GetCurVarDesc var_name = %s.", var_name.c_str());
@@ -701,16 +662,6 @@ ge::Status VarManager::RenewCurVarDesc(const std::string &var_name, ge::OpDescPt
return var_resource_->RenewCurVarDesc(var_name, std::move(op_desc)); return var_resource_->RenewCurVarDesc(var_name, std::move(op_desc));
} }


ge::Status VarManager::SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name,
const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
if (var_resource_ == nullptr) {
GELOGW("VarManager has not been init.");
return ge::INTERNAL_ERROR;
}
return var_resource_->SyncBroadCastData2Var(graph_id, var_name, var_tensor_desc, base_ptr);
}

bool VarManager::IsVarAddr(const int64_t &offset) { bool VarManager::IsVarAddr(const int64_t &offset) {
std::lock_guard<std::recursive_mutex> lock(mutex_); std::lock_guard<std::recursive_mutex> lock(mutex_);
if (var_resource_ == nullptr) { if (var_resource_ == nullptr) {


+ 0
- 15
ge/graph/manager/graph_var_manager.h View File

@@ -118,15 +118,6 @@ class VarResource {


ge::Status GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info); ge::Status GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info);


ge::Status SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name,
const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr);

ge::Status SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name,
const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr);

ge::Status SyncVarData(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc,
uint8_t *base_ptr);

Status SetTransRoad(const std::string &var_name, const VarTransRoad &trans_road) { Status SetTransRoad(const std::string &var_name, const VarTransRoad &trans_road) {
if (var_to_trans_road_.find(var_name) != var_to_trans_road_.end()) { if (var_to_trans_road_.find(var_name) != var_to_trans_road_.end()) {
GELOGW("Var name: %s has already set.", var_name.c_str()); GELOGW("Var name: %s has already set.", var_name.c_str());
@@ -234,16 +225,10 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager {


ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr); ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr);


ge::Status SyncVarData(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc,
uint8_t *base_ptr);

ge::Status SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info); ge::Status SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info);


ge::Status GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info); ge::Status GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info);


ge::Status SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc,
uint8_t *base_ptr);

ge::Status GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc); ge::Status GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc);


ge::Status RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc); ge::Status RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc);


+ 0
- 66
ge/graph/manager/trans_var_data_utils.cc View File

@@ -415,72 +415,6 @@ Status CopyTensorFromSrcVarNode(const NodePtr &var_src,
return SUCCESS; return SUCCESS;
} }
} // namespace } // namespace
Status TransVarDataUtils::SyncVarData2BroadCast(const string &var_name, const ge::GeTensorDesc &src_tensor_desc,
uint8_t *dst_addr, int64_t dst_addr_size, uint64_t session_id) {
GE_CHK_BOOL_RET_STATUS(dst_addr != nullptr, FAILED, "[Check][Param] dst addr is nullptr.");
uint8_t *src_host_addr = nullptr;
int64_t src_addr_size = 0;
GE_MAKE_GUARD_RTMEM(src_host_addr);
GE_CHK_STATUS_RET(SyncTensorToHost(var_name, src_tensor_desc, &src_host_addr, src_addr_size, session_id));

GELOGI("src_addr_size: %ld, dst_addr_size: %ld", src_addr_size, dst_addr_size);
GE_CHK_BOOL_RET_STATUS(src_addr_size == dst_addr_size, FAILED,
"[Check][Param] src_addr_size:%ld not equal to dst_addr_size:%ld",
src_addr_size, dst_addr_size);

GE_CHK_RT_RET(rtMemcpy(dst_addr, dst_addr_size, src_host_addr, src_addr_size, RT_MEMCPY_HOST_TO_DEVICE));
return SUCCESS;
}

Status TransVarDataUtils::SyncBroadCastData2Var(uint8_t *src_addr, int64_t src_addr_size, const string &var_name,
const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id) {
GE_CHK_BOOL_RET_STATUS(src_addr != nullptr, FAILED, "[Check][Param] src addr is nullptr. ");
uint8_t *host_addr = nullptr;
GE_MAKE_GUARD_RTMEM(host_addr);
GE_CHK_RT_RET(rtMallocHost(reinterpret_cast<void **>(&host_addr), src_addr_size));
GE_CHK_RT_RET(rtMemcpy(host_addr, src_addr_size, src_addr, src_addr_size, RT_MEMCPY_DEVICE_TO_HOST));

GE_CHK_STATUS_RET(
SyncTensorToDevice(var_name, reinterpret_cast<uint8_t *>(host_addr), src_addr_size, dst_tensor_desc, session_id));

return SUCCESS;
}

Status TransVarDataUtils::SyncTensorToHost(const string &var_name, const ge::GeTensorDesc &src_tensor_desc,
uint8_t **host_addr, int64_t &src_tensor_size, uint64_t session_id) {
GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(src_tensor_desc, src_tensor_size), "[Get][Size] from TensorDesc failed");

uint8_t *src_addr = nullptr;
GE_CHK_STATUS_RET(VarManager::Instance(session_id)->GetVarAddr(var_name, src_tensor_desc, &src_addr));
uint8_t *mem_addr =
src_addr -
static_cast<int64_t>(static_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemLogicBase())) +
static_cast<int64_t>(
reinterpret_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM)));
GE_CHK_RT_RET(rtMallocHost(reinterpret_cast<void **>(host_addr), src_tensor_size));

GE_CHK_RT_RET(rtMemcpy(*host_addr, src_tensor_size, mem_addr, src_tensor_size, RT_MEMCPY_DEVICE_TO_HOST));

GELOGI("SyncTensorToHost var_name %s, src_tensor_size %ld", var_name.c_str(), src_tensor_size);
return SUCCESS;
}

Status TransVarDataUtils::SyncTensorToDevice(const string &var_name, const uint8_t *host_addr, uint32_t addr_size,
const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id) {
uint8_t *dst_addr = nullptr;
GE_CHK_STATUS_RET(VarManager::Instance(session_id)->GetVarAddr(var_name, dst_tensor_desc, &dst_addr));
uint8_t *mem_addr =
dst_addr -
static_cast<int64_t>(static_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemLogicBase())) +
static_cast<int64_t>(
reinterpret_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM)));
GE_CHK_RT_RET(rtMemcpy(mem_addr, addr_size, host_addr, addr_size, RT_MEMCPY_HOST_TO_DEVICE));

GELOGI("SyncTensorToDevice var_name %s, addr_size %u", var_name.c_str(), addr_size);

return SUCCESS;
}

Status TransVarDataUtils::TransAllVarData(const vector<NodePtr> &variable_nodes, Status TransVarDataUtils::TransAllVarData(const vector<NodePtr> &variable_nodes,
uint64_t session_id, uint64_t session_id,
rtContext_t context, rtContext_t context,


+ 0
- 11
ge/graph/manager/trans_var_data_utils.h View File

@@ -29,11 +29,6 @@
namespace ge { namespace ge {
class TransVarDataUtils { class TransVarDataUtils {
public: public:
static ge::Status SyncVarData2BroadCast(const string &var_name, const ge::GeTensorDesc &src_tensor_desc,
uint8_t *dst_addr, int64_t dst_addr_size, uint64_t session_id_);
static ge::Status SyncBroadCastData2Var(uint8_t *src_addr, int64_t src_addr_size, const string &var_name,
const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id_);

static ge::Status TransAllVarData(const std::vector<NodePtr> &variable_nodes, static ge::Status TransAllVarData(const std::vector<NodePtr> &variable_nodes,
uint64_t session_id, uint64_t session_id,
rtContext_t context, rtContext_t context,
@@ -41,12 +36,6 @@ class TransVarDataUtils {
uint32_t thread_num = 16); uint32_t thread_num = 16);


static ge::Status CopyVarData(const ComputeGraphPtr &compute_graph, uint64_t session_id, uint32_t device_id); static ge::Status CopyVarData(const ComputeGraphPtr &compute_graph, uint64_t session_id, uint32_t device_id);

private:
static ge::Status SyncTensorToHost(const string &var_name, const ge::GeTensorDesc &src_tensor_desc,
uint8_t **host_addr, int64_t &addr_size, uint64_t session_id_);
static ge::Status SyncTensorToDevice(const string &var_name, const uint8_t *host_addr, uint32_t addr_size,
const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id_);
}; };
} // namespace ge } // namespace ge




+ 25
- 5
ge/graph/passes/infer_value_range_pass.cc View File

@@ -301,12 +301,26 @@ graphStatus InferValueRangePass::ConstructData(const GeTensorDesc &tensor_desc,
GeTensorPtr &output_ptr) { GeTensorPtr &output_ptr) {
std::vector<std::pair<int64_t, int64_t>> value_range; std::vector<std::pair<int64_t, int64_t>> value_range;
(void)tensor_desc.GetValueRange(value_range); (void)tensor_desc.GetValueRange(value_range);
if (static_cast<int64_t>(value_range.size()) != tensor_desc.GetShape().GetShapeSize()) {
GELOGW("Value range of input %s is invalid.", tensor_desc.GetName().c_str());
size_t value_range_data_num = value_range.size();
auto tensor_shape = tensor_desc.GetShape();
bool value_range_and_tensor_shape_matched = true;
if (tensor_shape.IsScalar()){
// scalar tensor has only one value_range pair
if (value_range_data_num != 1) {
value_range_and_tensor_shape_matched = false;
}
} else {
// normal tensor, value_range size is equal to tensor shape size.
if (static_cast<int64_t>(value_range_data_num) != tensor_shape.GetShapeSize()) {
value_range_and_tensor_shape_matched = false;
}
}
if (!value_range_and_tensor_shape_matched) {
GELOGW("Input %s value range and tensor shape do not match. Value range size is %zu, tensor shape is %s.",
tensor_desc.GetName().c_str(), value_range_data_num, formats::ShapeToString(tensor_shape).c_str());
return GRAPH_PARAM_INVALID; return GRAPH_PARAM_INVALID;
} }


size_t value_range_data_num = value_range.size();
unique_ptr<T[]> buf(new (std::nothrow) T[value_range_data_num]()); unique_ptr<T[]> buf(new (std::nothrow) T[value_range_data_num]());
if (buf == nullptr) { if (buf == nullptr) {
REPORT_INNER_ERROR("E19999", "New buf failed"); REPORT_INNER_ERROR("E19999", "New buf failed");
@@ -494,10 +508,16 @@ void InferValueRangePass::ConstructValueRange(const GeTensorPtr &left_tensor, co
GELOGI("Output tensor of cpu kernel does not have data, no way to set value range."); GELOGI("Output tensor of cpu kernel does not have data, no way to set value range.");
return; return;
} }
for (auto j = 0; j < left_tensor->GetTensorDesc().GetShape().GetShapeSize(); ++j) {
auto left_tensor_shape = left_tensor->GetTensorDesc().GetShape();
for (auto j = 0; j < left_tensor_shape.GetShapeSize(); ++j) {
auto left = static_cast<int64_t>(*(x + j)); auto left = static_cast<int64_t>(*(x + j));
auto right = static_cast<int64_t>(*(y + j)); auto right = static_cast<int64_t>(*(y + j));
value_range.emplace_back(std::make_pair(left, right));
value_range.emplace_back(left, right);
}

if (left_tensor_shape.IsScalar()) {
GELOGD("When inferring value range, output tensors of cpu kernel are scalar tensors.");
value_range.emplace_back(static_cast<int64_t>(*x), static_cast<int64_t>(*y));
} }
} }
} // namespace ge } // namespace ge

+ 1
- 0
ge/graph/preprocess/insert_op/util_insert_aipp_op.cc View File

@@ -568,6 +568,7 @@ Status InsertNewOpUtil::GetDataRelatedNode(NodePtr &node, std::map<NodePtr, std:
} }


std::unique_ptr<domi::AippOpParams> aipp_params(new (std::nothrow) domi::AippOpParams()); std::unique_ptr<domi::AippOpParams> aipp_params(new (std::nothrow) domi::AippOpParams());
GE_CHECK_NOTNULL(aipp_params);
ge::GeAttrValue::NAMED_ATTRS aipp_attr; ge::GeAttrValue::NAMED_ATTRS aipp_attr;
GE_CHK_BOOL_RET_STATUS(AttrUtils::GetNamedAttrs(data_op, ATTR_NAME_AIPP, aipp_attr), ACL_ERROR_GE_AIPP_NOT_EXIST, GE_CHK_BOOL_RET_STATUS(AttrUtils::GetNamedAttrs(data_op, ATTR_NAME_AIPP, aipp_attr), ACL_ERROR_GE_AIPP_NOT_EXIST,
"[Get][Attr] %s from op:%s failed", ATTR_NAME_AIPP.c_str(), data_op->GetName().c_str()); "[Get][Attr] %s from op:%s failed", ATTR_NAME_AIPP.c_str(), data_op->GetName().c_str());


+ 1
- 1
ge/graph/preprocess/multi_batch_copy_graph.cc View File

@@ -1206,7 +1206,7 @@ Status MultiBatchGraphCopyer::CheckCopyResult(const std::vector<NodePtr> &start_
auto dims = NodeUtils::GetOutputDesc(*node, kDataOutIndex).GetShape().GetDims(); auto dims = NodeUtils::GetOutputDesc(*node, kDataOutIndex).GetShape().GetDims();
if (!IsAllDimsPositive(dims)) { if (!IsAllDimsPositive(dims)) {
REPORT_CALL_ERROR("E19999", "Failed to copy multi batch graph, the node %s still has unknown shape %s", REPORT_CALL_ERROR("E19999", "Failed to copy multi batch graph, the node %s still has unknown shape %s",
node->GetName().c_str(), formats::ShapeToString(dims).c_str());
node->GetName().c_str(), formats::ShapeToString(dims).c_str());
GELOGE(INTERNAL_ERROR, "[Check][Param] Failed to copy multi batch graph, the node %s still has unknown shape %s", GELOGE(INTERNAL_ERROR, "[Check][Param] Failed to copy multi batch graph, the node %s still has unknown shape %s",
node->GetName().c_str(), formats::ShapeToString(dims).c_str()); node->GetName().c_str(), formats::ShapeToString(dims).c_str());
return INTERNAL_ERROR; return INTERNAL_ERROR;


+ 5
- 3
ge/hybrid/executor/hybrid_model_async_executor.cc View File

@@ -295,13 +295,15 @@ Status HybridModelAsyncExecutor::PrepareInputs(const InputData &current_data, Hy
} }
} }
tensor_desc->SetShape(shape); tensor_desc->SetShape(shape);
args.input_desc[input_index] = tensor_desc;
GELOGD("Update shape of input[%zu] to [%s]", input_index, tensor_desc->MutableShape().ToString().c_str());
GELOGD("Update shape[%s] of input[%zu] to [%s]",
shape.ToString().c_str(), input_index, tensor_desc->MutableShape().ToString().c_str());
GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorMemorySizeInBytes(*tensor_desc, tensor_size), GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorMemorySizeInBytes(*tensor_desc, tensor_size),
"[Invoke][GetTensorMemorySizeInBytes]Failed to calc tensor size," "[Invoke][GetTensorMemorySizeInBytes]Failed to calc tensor size,"
"index = %zu, shape = [%s], model_id = %u.", "index = %zu, shape = [%s], model_id = %u.",
input_index, tensor_desc->GetShape().ToString().c_str(), model_id_); input_index, tensor_desc->GetShape().ToString().c_str(), model_id_);
GELOGD("Input tensor[%zu] size = %zu", input_index, tensor_size);
GELOGD("Input tensor[%zu] size = %ld", input_index, tensor_size);
TensorUtils::SetSize(*tensor_desc, tensor_size);
args.input_desc[input_index] = tensor_desc;
} }


GE_CHECK_GE(tensor_size, 0); GE_CHECK_GE(tensor_size, 0);


+ 0
- 4
ge/hybrid/executor/hybrid_model_executor.cc View File

@@ -33,9 +33,6 @@ HybridModelExecutor::HybridModelExecutor(HybridModel *model, uint32_t device_id,
} }


HybridModelExecutor::~HybridModelExecutor() { HybridModelExecutor::~HybridModelExecutor() {
if (context_.rt_gen_context != nullptr) {
(void) rtCtxDestroy(context_.rt_gen_context);
}
} }


Status HybridModelExecutor::Init() { Status HybridModelExecutor::Init() {
@@ -139,7 +136,6 @@ Status HybridModelExecutor::Cleanup() {


Status HybridModelExecutor::InitExecutionContext() { Status HybridModelExecutor::InitExecutionContext() {
GE_CHK_RT_RET(rtCtxGetCurrent(&context_.rt_context)); GE_CHK_RT_RET(rtCtxGetCurrent(&context_.rt_context));
GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0));
GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context));


context_.global_step = model_->GetGlobalStep(); context_.global_step = model_->GetGlobalStep();


+ 0
- 1
ge/hybrid/executor/hybrid_model_pipeline_executor.cc View File

@@ -191,7 +191,6 @@ HybridModelPipelineExecutor::HybridModelPipelineExecutor(HybridModel *model, uin
} }


Status StageExecutor::InitExecutionContext() { Status StageExecutor::InitExecutionContext() {
GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0));
GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context));


context_.model = model_; context_.model = model_;


+ 9
- 2
ge/hybrid/executor/worker/task_compile_engine.cc View File

@@ -21,10 +21,17 @@
namespace ge { namespace ge {
namespace hybrid { namespace hybrid {
Status TaskCompileEngine::Compile(NodeState &node_state, GraphExecutionContext *context) { Status TaskCompileEngine::Compile(NodeState &node_state, GraphExecutionContext *context) {
const auto &node_item = *node_state.GetNodeItem();
GE_CHECK_NOTNULL(context); GE_CHECK_NOTNULL(context);
rtContext_t rt_gen_context = nullptr;
GE_CHK_RT_RET(rtCtxCreate(&rt_gen_context, RT_CTX_GEN_MODE, 0));
std::function<void()> callback = [&]() {
(void) rtCtxDestroy(rt_gen_context);
GE_CHK_RT(rtCtxSetCurrent(context->rt_context));
};
GE_MAKE_GUARD(rt_gen_context, callback);

const auto &node_item = *node_state.GetNodeItem();
RECORD_COMPILE_EVENT(context, node_item.NodeName().c_str(), "[Compile] Start"); RECORD_COMPILE_EVENT(context, node_item.NodeName().c_str(), "[Compile] Start");
GE_CHK_RT_RET(rtCtxSetCurrent(context->rt_gen_context));


if (context->ge_context != nullptr) { if (context->ge_context != nullptr) {
GetThreadLocalContext() = *context->ge_context; GetThreadLocalContext() = *context->ge_context;


+ 1
- 0
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -1044,6 +1044,7 @@ Status HybridModelBuilder::InitConstantOps() {
} else { } else {
var_tensor.reset(new(std::nothrow)TensorValue(nullptr, 0)); var_tensor.reset(new(std::nothrow)TensorValue(nullptr, 0));
} }
GE_CHECK_NOTNULL(var_tensor);
} else { } else {
GE_CHK_STATUS_RET_NOLOG(VarNodeToTensor(var_node, var_tensor)); GE_CHK_STATUS_RET_NOLOG(VarNodeToTensor(var_node, var_tensor));
GELOGD("Init const op tensor. name = %s, size = %ld", var_name.c_str(), var_tensor->GetSize()); GELOGD("Init const op tensor. name = %s, size = %ld", var_name.c_str(), var_tensor->GetSize());


+ 23
- 2
ge/hybrid/model/node_item.cc View File

@@ -24,6 +24,8 @@
namespace ge { namespace ge {
namespace hybrid { namespace hybrid {
namespace { namespace {
const uint8_t kMaxTransCount = 3;
const uint32_t kTransOpIoSize = 1;
const char *const kAttrNameOriginalFusionGraph = "_original_fusion_graph"; const char *const kAttrNameOriginalFusionGraph = "_original_fusion_graph";
const char *const kNodeTypeRetVal = "_RetVal"; const char *const kNodeTypeRetVal = "_RetVal";
const std::set<std::string> kControlOpTypes{ const std::set<std::string> kControlOpTypes{
@@ -39,6 +41,25 @@ const std::set<std::string> kMergeOpTypes{
MERGE, REFMERGE, STREAMMERGE MERGE, REFMERGE, STREAMMERGE
}; };


/// @brief Check whether `node` is an Enter op, or is fed by an Enter op through a short
///        chain of single-input nodes.
///
/// Patterns recognised (at most kMaxTransCount nodes are inspected, starting at `node`):
///   Enter -> node
///   Enter -> Cast -> node
///   Enter -> TransData -> Cast -> node
///
/// @param node  Node to start from; walked upstream along its single data input.
/// @return true if an Enter-type node is found within the inspected chain, false otherwise
///         (including when a node on the chain does not have exactly one data input).
bool IsEnterFeedNode(NodePtr node) {
  for (uint8_t i = 0; i < kMaxTransCount; ++i) {
    // Guard against a null node before it is dereferenced below.
    if (node == nullptr) {
      return false;
    }
    if (kEnterOpTypes.count(NodeUtils::GetNodeType(node)) > 0) {
      // The node name is a C string, so it must be printed with %s (was %u).
      GELOGD("Node[%s] is Enter feed node.", node->GetName().c_str());
      return true;
    }

    // Only keep walking upstream through nodes with exactly one data input / one input anchor.
    const auto all_nodes = node->GetInDataNodes();
    if (all_nodes.size() != kTransOpIoSize || node->GetAllInDataAnchorsSize() != kTransOpIoSize) {
      return false;
    }
    node = all_nodes.at(0);
  }
  return false;
}

Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) { Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) {
uint32_t parent_index = 0; uint32_t parent_index = 0;
if (!AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { if (!AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
@@ -399,7 +420,7 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) {
data_anchors.emplace(anchor_index); data_anchors.emplace(anchor_index);
} }
// If Enter feed Not Merge, take as root Node. // If Enter feed Not Merge, take as root Node.
if (IsEnterOp() && (node_item->node_type != STREAMMERGE)) {
if (IsEnterFeedNode(node) && (node_item->node_type != STREAMMERGE)) {
auto &data_anchors = node_item->enter_data_[this]; auto &data_anchors = node_item->enter_data_[this];
data_anchors.emplace(anchor_index); data_anchors.emplace(anchor_index);
} }
@@ -419,7 +440,7 @@ void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) {
node_item->root_ctrl_.emplace(this); node_item->root_ctrl_.emplace(this);
} }
// If Enter feed control signal, take as root Node. // If Enter feed control signal, take as root Node.
if (IsEnterOp() && (node_item->node_type != STREAMMERGE && node_item->node_type != STREAMACTIVE)) {
if (IsEnterFeedNode(node) && (node_item->node_type != STREAMMERGE && node_item->node_type != STREAMACTIVE)) {
node_item->enter_ctrl_.emplace(this); node_item->enter_ctrl_.emplace(this);
} }
GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str());


+ 7
- 2
ge/ir_build/option_utils.cc View File

@@ -50,6 +50,8 @@ const std::set<std::string> kBufferOptimizeSupportOption = {"l1_optimize", "l2_o
const char *const kBufferOptimizeSupport = "only support l2_optimize, off_optimize"; const char *const kBufferOptimizeSupport = "only support l2_optimize, off_optimize";
const char *const IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT = "high_performance"; const char *const IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT = "high_performance";
const char *const IR_OPTION_OP_SELECT_IMPLMODE_PRECISON = "high_precision"; const char *const IR_OPTION_OP_SELECT_IMPLMODE_PRECISON = "high_precision";
const char *const IR_OPTION_OP_SELECT_IMPLMODE_HIGH_PRECISION_FOR_ALL = "high_precision_for_all";
const char *const IR_OPTION_OP_SELECT_IMPLMODE_HIGH_PERFORMANCE_FOR_ALL = "high_performance_for_all";
const char *const kInputShapeSample1 = "\"input_name1:n1,c1,h1,w1\""; const char *const kInputShapeSample1 = "\"input_name1:n1,c1,h1,w1\"";
const char *const kInputShapeSample2 = "\"input_name1:1,3,224,224\""; const char *const kInputShapeSample2 = "\"input_name1:1,3,224,224\"";
const char *const kSplitError1 = "size not equal to 2 split by \":\""; const char *const kSplitError1 = "size not equal to 2 split by \":\"";
@@ -57,7 +59,8 @@ const char *const kEmptyError = "can not be empty";
const char *const kFloatNumError = "exist float number"; const char *const kFloatNumError = "exist float number";
const char *const kDigitError = "is not digit"; const char *const kDigitError = "is not digit";
const char *const kCompressWeightError = "it must be appointed when appoint parameter[--optypelist_for_implmode]"; const char *const kCompressWeightError = "it must be appointed when appoint parameter[--optypelist_for_implmode]";
const char *const kSelectImplmodeError = "only support high_performance, high_precision";
const char *const kSelectImplmodeError = "only support high_performance, high_precision, "
"high_precision_for_all, high_performance_for_all";
const char *const kDynamicBatchSizeError = "It can only contains digit, \",\", \" \""; const char *const kDynamicBatchSizeError = "It can only contains digit, \",\", \" \"";
const char *const kDynamicImageSizeError = "It can only contains digit, \",\", \" \" and \";\""; const char *const kDynamicImageSizeError = "It can only contains digit, \",\", \" \" and \";\"";
const char *const kKeepDtypeError = "file not found"; const char *const kKeepDtypeError = "file not found";
@@ -782,7 +785,9 @@ Status CheckImplmodeParamValid(const std::string &optypelist_for_implmode, std::
op_select_implmode = IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT; op_select_implmode = IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT;
} else { } else {
if (op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT && if (op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT &&
op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_PRECISON) {
op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_PRECISON &&
op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_HIGH_PRECISION_FOR_ALL &&
op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_HIGH_PERFORMANCE_FOR_ALL) {
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"--op_select_implmode", op_select_implmode.c_str(), {"--op_select_implmode", op_select_implmode.c_str(),
kSelectImplmodeError}); kSelectImplmodeError});


+ 4
- 3
ge/offline/main.cc View File

@@ -143,7 +143,8 @@ DEFINE_string(output_type, "",


DEFINE_string(op_select_implmode, "", DEFINE_string(op_select_implmode, "",
"Optional; op select implmode! " "Optional; op select implmode! "
"Support high_precision, high_performance.");
"Support high_precision, high_performance, "
"high_precision_for_all, high_performance_for_all.");


DEFINE_string(optypelist_for_implmode, "", DEFINE_string(optypelist_for_implmode, "",
"Optional; Nodes need use implmode selected in op_select_implmode " "Optional; Nodes need use implmode selected in op_select_implmode "
@@ -311,8 +312,8 @@ class GFlagUtils {
"scenarios by using a configuration file.\n" "scenarios by using a configuration file.\n"
" --auto_tune_mode Set tune mode. E.g.: \"GA,RL\", support configure multiple, spit by ,\n" " --auto_tune_mode Set tune mode. E.g.: \"GA,RL\", support configure multiple, spit by ,\n"
" --op_bank_path Set the path of the custom repository generated after operator tuning with Auto Tune.\n" " --op_bank_path Set the path of the custom repository generated after operator tuning with Auto Tune.\n"
" --op_select_implmode Set op select implmode. Support high_precision, high_performance. "
"default: high_performance\n"
" --op_select_implmode Set op select implmode. Support high_precision, high_performance, "
"high_precision_for_all, high_performance_for_all. default: high_performance\n"
" --optypelist_for_implmode Appoint which op to select implmode, cooperated with op_select_implmode.\n" " --optypelist_for_implmode Appoint which op to select implmode, cooperated with op_select_implmode.\n"
" Separate multiple nodes with commas (,). Use double quotation marks (\") " " Separate multiple nodes with commas (,). Use double quotation marks (\") "
"to enclose each argument. E.g.: \"node_name1,node_name2\"\n" "to enclose each argument. E.g.: \"node_name1,node_name2\"\n"


+ 68
- 85
ge/single_op/single_op_model.cc View File

@@ -95,35 +95,6 @@ Status CheckInferDepend(GeModelPtr &ge_model, bool &is_infer_depend, bool &is_ho
} }
return SUCCESS; return SUCCESS;
} }

// Decide whether the single-op model needs a hybrid model.
// `flag` is set to true when the task list contains a TE kernel task and a device-to-host
// copy is needed (infer-depend without host memory), or when it contains more than one
// TE kernel task. `flag` is never written otherwise, so the caller's initial value is kept.
// Returns the status of CheckInferDepend on failure, SUCCESS otherwise.
Status NeedHybridModel(GeModelPtr &ge_model, bool &flag) {
bool is_infer_depend = false;
bool is_host_mem = false;
GE_CHK_STATUS_RET(CheckInferDepend(ge_model, is_infer_depend, is_host_mem), "[Check][InferDepend] failed.");
bool need_d2h_cpy = is_infer_depend && !is_host_mem;
auto tasks = ge_model->GetModelTaskDefPtr()->task();
// Number of TE kernel tasks seen so far.
int32_t kernel_task_num = 0;
for (int i = 0; i < tasks.size(); ++i) {
auto task_type = static_cast<rtModelTaskType_t>(tasks[i].type());
if (task_type == RT_MODEL_TASK_KERNEL || task_type == RT_MODEL_TASK_ALL_KERNEL) {
// The kernel context lives in a different sub-message depending on the task type.
const auto &context = task_type == RT_MODEL_TASK_KERNEL ? tasks[i].kernel().context() :
tasks[i].kernel_with_handle().context();
auto kernel_type = static_cast<ccKernelType>(context.kernel_type());
if (kernel_type == ccKernelType::TE) {
// A TE kernel combined with a required device-to-host copy needs the hybrid model.
if (need_d2h_cpy) {
flag = true;
return SUCCESS;
}
kernel_task_num++;
// More than one TE kernel task also needs the hybrid model.
if (kernel_task_num > 1) {
flag = true;
return SUCCESS;
}
}
}
}
return SUCCESS;
}
} // namespace } // namespace


SingleOpModel::SingleOpModel(const std::string &model_name, const void *model_data, uint32_t model_size) SingleOpModel::SingleOpModel(const std::string &model_name, const void *model_data, uint32_t model_size)
@@ -558,14 +529,15 @@ Status SingleOpModel::BuildOp(StreamResource &resource, SingleOp &single_op) {
return BuildTaskList(&resource, single_op); return BuildTaskList(&resource, single_op);
} }


Status SingleOpModel::BuildModelTaskKernel(StreamResource *stream_resource, const TaskDef &task_def,
DynamicSingleOp &single_op) {
auto task_type = static_cast<rtModelTaskType_t>(task_def.type());
const auto &context = task_type == RT_MODEL_TASK_KERNEL ? task_def.kernel().context() :
task_def.kernel_with_handle().context();
Status SingleOpModel::BuildTaskListForDynamicOp(StreamResource *stream_resource, DynamicSingleOp &single_op) {
auto ge_model = model_helper_.GetGeModel();
GE_CHECK_NOTNULL(ge_model);


auto kernel_type = static_cast<ccKernelType>(context.kernel_type());
if (kernel_type == ccKernelType::TE) {
auto compute_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph());
GE_CHECK_NOTNULL(compute_graph);
single_op.compute_graph_ = compute_graph;
if (tbe_tasks_.size() > 0) {
const auto &task_def = tbe_tasks_[0];
GELOGD("Building TBE task."); GELOGD("Building TBE task.");
TbeOpTask *tbe_task = nullptr; TbeOpTask *tbe_task = nullptr;
GE_CHK_STATUS_RET_NOLOG(BuildKernelTask(task_def, &tbe_task)); GE_CHK_STATUS_RET_NOLOG(BuildKernelTask(task_def, &tbe_task));
@@ -575,71 +547,81 @@ Status SingleOpModel::BuildModelTaskKernel(StreamResource *stream_resource, cons
tbe_task->stream_resource_ = stream_resource; tbe_task->stream_resource_ = stream_resource;
} }
single_op.op_task_.reset(tbe_task); single_op.op_task_.reset(tbe_task);
} else if (kernel_type == ccKernelType::AI_CPU || kernel_type == ccKernelType::CUST_AI_CPU) {
GELOGD("Building AICPU_CC task");
OpTask *task = nullptr;
uint64_t dynamic_singleop_kernel_id = aicpu_kernel_id++;
GELOGI("Build dynamic singleOp CCTask, kernel_id = %lu", dynamic_singleop_kernel_id);
GE_CHK_STATUS_RET_NOLOG(BuildCpuKernelTask(task_def.kernel(), &task, dynamic_singleop_kernel_id));
task->SetModelArgs(model_name_, model_id_);
single_op.op_task_.reset(task);
} else {
GELOGE(ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID,
"[Check][Param:TaskDef]Only TBE, AI_CPU, CUST_AI_CPU kernel are supported, but got %u",
context.kernel_type());
REPORT_INNER_ERROR("E19999",
"BuildModelTaskKernel fail for got:%u not supported, Only TBE, AI_CPU, CUST_AI_CPU kernel are supported.",
context.kernel_type());
return ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID;
}
return SUCCESS;
}

Status SingleOpModel::BuildTaskListForDynamicOp(StreamResource *stream_resource, DynamicSingleOp &single_op) {
auto ge_model = model_helper_.GetGeModel();
GE_CHECK_NOTNULL(ge_model);

auto compute_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph());
GE_CHECK_NOTNULL(compute_graph);
single_op.compute_graph_ = compute_graph;
auto tasks = ge_model->GetModelTaskDefPtr()->task();
for (int i = 0; i < tasks.size(); ++i) {
const TaskDef &task_def = tasks[i];
GELOGI("[%s] Task[%d], type = [%u], DebugString = [%s]", model_name_.c_str(), i, task_def.type(),
task_def.DebugString().c_str());
} else if (aicpu_tasks_.size() > 0) {
const auto &task_def = aicpu_tasks_[0];
auto task_type = static_cast<rtModelTaskType_t>(task_def.type()); auto task_type = static_cast<rtModelTaskType_t>(task_def.type());
if (task_type == RT_MODEL_TASK_KERNEL || task_type == RT_MODEL_TASK_ALL_KERNEL) {
if (single_op.op_task_ != nullptr) {
GELOGE(ACL_ERROR_GE_OP_TASK_TYPE_INVALID, "[Check][TaskType]Do not support dynamic op with multiple tasks.");
REPORT_INNER_ERROR("E19999",
"BuildTaskListForDynamicOp fail for Do not support dynamic op with multiple tasks.");
return ACL_ERROR_GE_OP_TASK_TYPE_INVALID;
}
GE_CHK_STATUS_RET_NOLOG(BuildModelTaskKernel(stream_resource, task_def, single_op));
if (task_type == RT_MODEL_TASK_KERNEL) {
GELOGD("Building AICPU_CC task");
OpTask *task = nullptr;
uint64_t dynamic_singleop_kernel_id = aicpu_kernel_id++;
GELOGI("Build dynamic singleOp CCTask, kernel_id = %lu", dynamic_singleop_kernel_id);
GE_CHK_STATUS_RET_NOLOG(BuildCpuKernelTask(task_def.kernel(), &task, dynamic_singleop_kernel_id));
task->SetModelArgs(model_name_, model_id_);
single_op.op_task_.reset(task);
} else if (task_type == RT_MODEL_TASK_KERNEL_EX) { } else if (task_type == RT_MODEL_TASK_KERNEL_EX) {
if (single_op.op_task_ != nullptr) {
GELOGE(ACL_ERROR_GE_OP_TASK_TYPE_INVALID, "[Check][TaskType]Do not support dynamic op with multiple tasks.");
REPORT_INNER_ERROR("E19999",
"BuildTaskListForDynamicOp fail for Do not support dynamic op with multiple tasks.");
return ACL_ERROR_GE_OP_TASK_TYPE_INVALID;
}
GELOGD("Building AICPU_TF task"); GELOGD("Building AICPU_TF task");
AiCpuTask *aicpu_task = nullptr; AiCpuTask *aicpu_task = nullptr;
uint64_t dynamic_singleop_kernel_id = aicpu_kernel_id++; uint64_t dynamic_singleop_kernel_id = aicpu_kernel_id++;
GELOGI("Build dynamic singleOp TfTask, kernel_id = %lu", dynamic_singleop_kernel_id); GELOGI("Build dynamic singleOp TfTask, kernel_id = %lu", dynamic_singleop_kernel_id);
GE_CHK_STATUS_RET_NOLOG(BuildKernelExTask(task_def.kernel_ex(), &aicpu_task, dynamic_singleop_kernel_id)); GE_CHK_STATUS_RET_NOLOG(BuildKernelExTask(task_def.kernel_ex(), &aicpu_task, dynamic_singleop_kernel_id));
if (aicpu_task->GetUnknownType() == DEPEND_COMPUTE) { if (aicpu_task->GetUnknownType() == DEPEND_COMPUTE) {
if (i >= tasks.size() - 1) {
if (aicpu_tasks_.size() < 2) {
GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[Check][Task]The copy task of the fourth operator was not found."); GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[Check][Task]The copy task of the fourth operator was not found.");
REPORT_INNER_ERROR("E19999", "The copy task of the fourth operator was not found."); REPORT_INNER_ERROR("E19999", "The copy task of the fourth operator was not found.");
return ACL_ERROR_GE_PARAM_INVALID; return ACL_ERROR_GE_PARAM_INVALID;
} }
++i;
const TaskDef &copy_task_def = tasks[i];
const TaskDef &copy_task_def = aicpu_tasks_[1];
GE_CHK_STATUS_RET_NOLOG(aicpu_task->SetMemCopyTask(copy_task_def.kernel_ex())); GE_CHK_STATUS_RET_NOLOG(aicpu_task->SetMemCopyTask(copy_task_def.kernel_ex()));
} }
aicpu_task->SetModelArgs(model_name_, model_id_); aicpu_task->SetModelArgs(model_name_, model_id_);
single_op.op_task_.reset(aicpu_task); single_op.op_task_.reset(aicpu_task);
}
}
return SUCCESS;
}

/// @brief Decide whether this single-op model has to be executed through the hybrid model.
///
/// The hybrid model is required when any of the following holds:
///   - the op is infer-depend and its data is not in host memory (a device-to-host copy is needed);
///   - more than one TBE task was collected in tbe_tasks_;
///   - both tbe_tasks_ and aicpu_tasks_ are non-empty.
///
/// @param ge_model           Model passed on to CheckInferDepend.
/// @param need_hybrid_model  [out] Always assigned on success.
/// @return Status of CheckInferDepend on failure, SUCCESS otherwise.
Status SingleOpModel::NeedHybridModel(GeModelPtr &ge_model, bool &need_hybrid_model) {
  bool infer_depend = false;
  bool host_mem = false;
  GE_CHK_STATUS_RET(CheckInferDepend(ge_model, infer_depend, host_mem), "[Check][InferDepend] failed.");

  const bool has_tbe_task = !tbe_tasks_.empty();
  const bool has_aicpu_task = !aicpu_tasks_.empty();

  const bool d2h_copy_required = infer_depend && !host_mem;
  const bool several_aicore_tasks = tbe_tasks_.size() > 1;
  const bool aicore_with_aicpu = has_tbe_task && has_aicpu_task;

  need_hybrid_model = d2h_copy_required || several_aicore_tasks || aicore_with_aicpu;
  return SUCCESS;
}

Status SingleOpModel::ParseTasks() {
auto ge_model = model_helper_.GetGeModel();
GE_CHECK_NOTNULL(ge_model);

auto tasks = ge_model->GetModelTaskDefPtr()->task();
for (int i = 0; i < tasks.size(); ++i) {
TaskDef &task_def = tasks[i];
GELOGI("[%s] Task[%d], type = [%u], DebugString = [%s]", model_name_.c_str(), i, task_def.type(),
task_def.DebugString().c_str());
auto task_type = static_cast<rtModelTaskType_t>(task_def.type());
if (task_type == RT_MODEL_TASK_KERNEL) {
const auto &kernel_def = task_def.kernel();
const auto &context = kernel_def.context();
auto kernel_type = static_cast<ccKernelType>(context.kernel_type());
if (kernel_type == ccKernelType::TE) {
tbe_tasks_.emplace_back(task_def);
} else if (kernel_type == ccKernelType::AI_CPU || kernel_type == ccKernelType::CUST_AI_CPU) {
aicpu_tasks_.emplace_back(task_def);
} else {
GELOGE(ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID,
"[Check][Param:TaskDef]Only TBE, AI_CPU, CUST_AI_CPU kernel are supported, but got %u",
context.kernel_type());
REPORT_INNER_ERROR("E19999",
"BuildModelTaskKernel fail for got:%u not supported, Only TBE, AI_CPU, CUST_AI_CPU kernel are supported.",
context.kernel_type());
return ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID;
}
} else if (task_type == RT_MODEL_TASK_ALL_KERNEL) {
tbe_tasks_.emplace_back(task_def);
} else if (task_type == RT_MODEL_TASK_KERNEL_EX) {
aicpu_tasks_.emplace_back(task_def);
} else { } else {
// skip // skip
GELOGD("Skip task type: %d", static_cast<int>(task_type)); GELOGD("Skip task type: %d", static_cast<int>(task_type));
@@ -654,6 +636,7 @@ Status SingleOpModel::BuildDynamicOp(StreamResource &resource, DynamicSingleOp &
GE_CHK_STATUS_RET_NOLOG(InitModelMem(resource)); GE_CHK_STATUS_RET_NOLOG(InitModelMem(resource));
model_params_.memory_size = UINT64_MAX; model_params_.memory_size = UINT64_MAX;
model_params_.graph_is_dynamic = true; model_params_.graph_is_dynamic = true;
GE_CHK_STATUS_RET(ParseTasks(), "[Parse][Tasks] failed.");


auto ge_model = model_helper_.GetGeModel(); auto ge_model = model_helper_.GetGeModel();
GE_CHECK_NOTNULL(ge_model); GE_CHECK_NOTNULL(ge_model);


+ 5
- 2
ge/single_op/single_op_model.h View File

@@ -71,13 +71,16 @@ class SingleOpModel {
Status BuildKernelTask(const domi::TaskDef &task_def, TbeOpTask **task); Status BuildKernelTask(const domi::TaskDef &task_def, TbeOpTask **task);
Status BuildKernelExTask(const domi::KernelExDef &kernel_def, AiCpuTask **task, uint64_t kernel_id); Status BuildKernelExTask(const domi::KernelExDef &kernel_def, AiCpuTask **task, uint64_t kernel_id);
Status BuildCpuKernelTask(const domi::KernelDef &kernel_def, OpTask **task, uint64_t kernel_id); Status BuildCpuKernelTask(const domi::KernelDef &kernel_def, OpTask **task, uint64_t kernel_id);
Status BuildModelTaskKernel(StreamResource *stream_resource, const domi::TaskDef &task_def,
DynamicSingleOp &single_op);


static void ParseOpModelParams(ModelHelper &model_helper, SingleOpModelParam &param); static void ParseOpModelParams(ModelHelper &model_helper, SingleOpModelParam &param);
void ParseArgTable(OpTask *task, SingleOp &op); void ParseArgTable(OpTask *task, SingleOp &op);
Status InitHybridModelExecutor(const StreamResource &resource, const GeModelPtr &ge_model, SingleOp &single_op); Status InitHybridModelExecutor(const StreamResource &resource, const GeModelPtr &ge_model, SingleOp &single_op);
Status SetHostMemTensor(DynamicSingleOp &single_op); Status SetHostMemTensor(DynamicSingleOp &single_op);
Status NeedHybridModel(GeModelPtr &ge_model, bool &flag);
Status ParseTasks();

std::vector<domi::TaskDef> tbe_tasks_;
std::vector<domi::TaskDef> aicpu_tasks_;


std::string model_name_; std::string model_name_;
uint32_t model_id_ = 0; uint32_t model_id_ = 0;


+ 5
- 1
ge/single_op/task/op_task.h View File

@@ -33,6 +33,10 @@
#include "register/op_tiling.h" #include "register/op_tiling.h"


namespace ge { namespace ge {
// NOTE(review): this anonymous namespace lives in a header (op_task.h), so every translation
// unit that includes it gets its own copy of the constant — confirm this is intended.
namespace {
// Number of entries in MemcpyAsyncTask::addresses_.
const int kAddressNum = 2;
} // namespace

class StreamResource; class StreamResource;
struct SingleOpModelParam; struct SingleOpModelParam;
class OpTask { class OpTask {
@@ -264,7 +268,7 @@ class MemcpyAsyncTask : public OpTask {
friend class SingleOpModel; friend class SingleOpModel;
friend class RtsKernelTaskBuilder; friend class RtsKernelTaskBuilder;


uintptr_t addresses_[2];
uintptr_t addresses_[kAddressNum];
size_t dst_max_; size_t dst_max_;
size_t count_; size_t count_;
rtMemcpyKind_t kind_; rtMemcpyKind_t kind_;


+ 24
- 1
ge/single_op/task/tbe_task_builder.cc View File

@@ -104,7 +104,7 @@ Status TbeTaskBuilder::DoRegisterBinary(const OpKernelBin &kernel_bin, void **bi
binary.version = 0; binary.version = 0;
binary.data = kernel_bin.GetBinData(); binary.data = kernel_bin.GetBinData();
binary.length = kernel_bin.GetBinDataSize(); binary.length = kernel_bin.GetBinDataSize();
binary.magic = param.core_type == 0 ? RT_DEV_BINARY_MAGIC_ELF : RT_DEV_BINARY_MAGIC_ELF_AIVEC;
GE_CHK_STATUS_RET_NOLOG(GetMagic(binary.magic));
Status ret = 0; Status ret = 0;
if (task_def_.type() == RT_MODEL_TASK_ALL_KERNEL) { if (task_def_.type() == RT_MODEL_TASK_ALL_KERNEL) {
ret = rtRegisterAllKernel(&binary, bin_handle); ret = rtRegisterAllKernel(&binary, bin_handle);
@@ -416,4 +416,27 @@ Status TbeTaskBuilder::InitTilingInfo(TbeOpTask &task) {
task.EnableDynamicSupport(node_, tiling_buffer, static_cast<uint32_t>(max_size)); task.EnableDynamicSupport(node_, tiling_buffer, static_cast<uint32_t>(max_size));
return SUCCESS; return SUCCESS;
} }

/// @brief Translate the op's TVM_ATTR_NAME_MAGIC string attribute into the runtime binary magic value.
///
/// Recognised attribute values are "RT_DEV_BINARY_MAGIC_ELF", "RT_DEV_BINARY_MAGIC_ELF_AIVEC"
/// and "RT_DEV_BINARY_MAGIC_ELF_AICUBE". Any other value is reported as an error; this includes
/// a missing attribute, because the string then stays empty and matches none of them.
///
/// @param magic  [out] Receives the magic constant; left untouched on failure.
/// @return SUCCESS when the attribute maps to a known magic, PARAM_INVALID otherwise.
Status TbeTaskBuilder::GetMagic(uint32_t &magic) const {
  std::string json_string;
  // The attribute may be absent; in that case json_string stays empty and is rejected below.
  GE_IF_BOOL_EXEC(AttrUtils::GetStr(op_desc_, TVM_ATTR_NAME_MAGIC, json_string),
                  GELOGD("Get attr:%s success, value:%s.", TVM_ATTR_NAME_MAGIC.c_str(), json_string.c_str()));
  if (json_string == "RT_DEV_BINARY_MAGIC_ELF") {
    magic = RT_DEV_BINARY_MAGIC_ELF;
  } else if (json_string == "RT_DEV_BINARY_MAGIC_ELF_AIVEC") {
    magic = RT_DEV_BINARY_MAGIC_ELF_AIVEC;
  } else if (json_string == "RT_DEV_BINARY_MAGIC_ELF_AICUBE") {
    magic = RT_DEV_BINARY_MAGIC_ELF_AICUBE;
  } else {
    REPORT_INNER_ERROR("E19999", "Attr:%s in op:%s(%s), value:%s check invalid",
                       TVM_ATTR_NAME_MAGIC.c_str(), op_desc_->GetName().c_str(),
                       op_desc_->GetType().c_str(), json_string.c_str());
    GELOGE(PARAM_INVALID, "[Check][Param] Attr:%s in op:%s(%s), value:%s check invalid",
           TVM_ATTR_NAME_MAGIC.c_str(), op_desc_->GetName().c_str(),
           op_desc_->GetType().c_str(), json_string.c_str());
    return PARAM_INVALID;
  }
  return SUCCESS;
}

} // namespace ge } // namespace ge

+ 1
- 0
ge/single_op/task/tbe_task_builder.h View File

@@ -105,6 +105,7 @@ class TbeTaskBuilder {
const SingleOpModelParam &param); const SingleOpModelParam &param);
Status DoRegisterBinary(const OpKernelBin &kernel_bin, void **bin_handle, const SingleOpModelParam &param) const; Status DoRegisterBinary(const OpKernelBin &kernel_bin, void **bin_handle, const SingleOpModelParam &param) const;
Status DoRegisterMeta(void *bin_handle); Status DoRegisterMeta(void *bin_handle);
Status GetMagic(uint32_t &magic) const;


static Status DoRegisterFunction(void *bin_handle, const char *stub_name, const char *kernel_name); static Status DoRegisterFunction(void *bin_handle, const char *stub_name, const char *kernel_name);




+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit 9e4a51a9602195b82e326b853f5adbfefc3972b6
Subproject commit 9c9907b76a457f456072af96b8cbcfb7943beccc

+ 1
- 1
parser

@@ -1 +1 @@
Subproject commit 79536a196f89cf7a1f5852ff7304b9a7d7b12eff
Subproject commit 15a27afefe45f2abdb78787d629163aab9437599

+ 15
- 0
scripts/env/Dockerfile View File

@@ -38,5 +38,20 @@ RUN wget https://github.com/ccup/lcov/archive/refs/tags/add_lcov.tar.gz -O add_l


ENV PROJECT_HOME=/code/Turing/graphEngine ENV PROJECT_HOME=/code/Turing/graphEngine


RUN mkdir /var/run/sshd
RUN echo "root:root" | chpasswd
RUN sed -i 's/\#PermitRootLogin prohibit-password/PermitRootLogin yes/' /etc/ssh/sshd_config
RUN sed 's@session\s*required\s*pam_loginuid.so@session optional pam_loginuid.so@g' -i /etc/pam.d/sshd

ENV NOTVISIBLE "in users profile"
RUN echo "export VISIBLE=now" >> /etc/profile

EXPOSE 22 7777

RUN useradd -ms /bin/bash debugger
RUN echo "debugger:ge123" | chpasswd

# Run sshd in the foreground as the default command; exec form must be a valid JSON array.
CMD ["/usr/sbin/sshd", "-D"]

RUN echo "alias ge=/code/Turing/graphEngine/scripts/ge.sh">>~/.bashrc RUN echo "alias ge=/code/Turing/graphEngine/scripts/ge.sh">>~/.bashrc



+ 2
- 2
scripts/env/ge_env.sh View File

@@ -21,7 +21,7 @@ MOUNT_PROJECT_HOME=$(cd $PROJECT_HOME || return; pwd)


DOCKER_BUILD_ENV_NAME=${MOUNT_PROJECT_HOME#*/} DOCKER_BUILD_ENV_NAME=${MOUNT_PROJECT_HOME#*/}
DOCKER_BUILD_ENV_NAME=${DOCKER_BUILD_ENV_NAME//\//\_} DOCKER_BUILD_ENV_NAME=${DOCKER_BUILD_ENV_NAME//\//\_}
DOCKER_IMAGE_TAG=ge_build_env.1.0.6
DOCKER_IMAGE_TAG=ge_build_env.1.0.9
DOCKER_IAMGE_NAME=joycode2art/turing DOCKER_IAMGE_NAME=joycode2art/turing
DOCKER_FULL_IMAGE_NAME=${DOCKER_IAMGE_NAME}:${DOCKER_IMAGE_TAG} DOCKER_FULL_IMAGE_NAME=${DOCKER_IAMGE_NAME}:${DOCKER_IMAGE_TAG}


@@ -61,7 +61,7 @@ function enter_docker_env(){
if test -z "$(docker images |grep ${DOCKER_IAMGE_NAME} | grep ${DOCKER_IMAGE_TAG})"; then if test -z "$(docker images |grep ${DOCKER_IAMGE_NAME} | grep ${DOCKER_IMAGE_TAG})"; then
echo "please run 'ge env --pull' to download images first!" echo "please run 'ge env --pull' to download images first!"
elif test -z "$(docker ps -a |grep ${DOCKER_BUILD_ENV_NAME})"; then elif test -z "$(docker ps -a |grep ${DOCKER_BUILD_ENV_NAME})"; then
$docker_cmd run -it -v ${MOUNT_PROJECT_HOME}:/code/Turing/graphEngine --workdir ${docker_work_dir} --name ${DOCKER_BUILD_ENV_NAME} ${DOCKER_FULL_IMAGE_NAME} ${docker_bash_dir}
$docker_cmd run -p 7002:22 -p 7003:7777 --privileged=true -it -v ${MOUNT_PROJECT_HOME}:/code/Turing/graphEngine --workdir ${docker_work_dir} --name ${DOCKER_BUILD_ENV_NAME} ${DOCKER_FULL_IMAGE_NAME} ${docker_bash_dir}
elif test -z "$(docker ps |grep ${DOCKER_BUILD_ENV_NAME})"; then elif test -z "$(docker ps |grep ${DOCKER_BUILD_ENV_NAME})"; then
$docker_cmd start ${DOCKER_BUILD_ENV_NAME} $docker_cmd start ${DOCKER_BUILD_ENV_NAME}
$docker_cmd exec -w ${docker_work_dir} -it ${DOCKER_BUILD_ENV_NAME} ${docker_bash_dir} $docker_cmd exec -w ${docker_work_dir} -it ${DOCKER_BUILD_ENV_NAME} ${docker_bash_dir}


+ 1
- 0
tests/depends/cce/CMakeLists.txt View File

@@ -60,6 +60,7 @@ set(SRCS
"${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" "${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc"
"${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" "${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc"
"${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" "${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc"
"${GE_CODE_DIR}/metadef/graph/utils/dumper/ge_graph_dumper.cc"
"${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" "${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc"
"${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" "${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc"
"${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc" "${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc"


+ 0
- 13
tests/framework/CMakeLists.txt View File

@@ -17,16 +17,3 @@ include(cmake/graphengine.cmake)
add_subdirectory(easy_graph) add_subdirectory(easy_graph)
add_subdirectory(ge_graph_dsl) add_subdirectory(ge_graph_dsl)
add_subdirectory(ge_running_env) add_subdirectory(ge_running_env)

file(GLOB_RECURSE UTILS_SRC CONFIGURE_DEPENDS
"utils/*.cc"
)

add_library(framework STATIC ${UTILS_SRC})

target_include_directories(framework
PUBLIC utils/
)

set_target_properties(framework PROPERTIES CXX_STANDARD 11)
target_link_libraries(framework PUBLIC ge_graph_dsl ge_with_env)

+ 19
- 3
tests/framework/easy_graph/include/easy_graph/builder/graph_dsl.h View File

@@ -26,16 +26,32 @@ EG_NS_BEGIN


//////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////
namespace detail { namespace detail {
template<typename GRAPH_BUILDER>
template <typename GRAPH_BUILDER>
Graph BuildGraph(const char *name, GRAPH_BUILDER builderInDSL) { Graph BuildGraph(const char *name, GRAPH_BUILDER builderInDSL) {
GraphBuilder builder(name); GraphBuilder builder(name);
builderInDSL(builder); builderInDSL(builder);
return std::move(*builder); return std::move(*builder);
} }

struct GraphDefiner {
GraphDefiner(const char *defaultName, const char *specifiedName = nullptr) {
name = specifiedName ? specifiedName : defaultName;
}

template <typename USER_BUILDER>
auto operator|(USER_BUILDER &&userBuilder) {
GraphBuilder graphBuilder{name};
std::forward<USER_BUILDER>(userBuilder)(graphBuilder);
return *graphBuilder;
}

private:
const char *name;
};

} // namespace detail } // namespace detail


#define HAS_NAME(...) NOT_EMPTY_SELECT(__VA_ARGS__)
#define DEF_GRAPH(G, ...) ::EG_NS::Graph G = ::EG_NS::detail::BuildGraph(HAS_NAME(__VA_ARGS__)(__VA_ARGS__, #G), [&](::EG_NS::GraphBuilder& BUILDER)
#define DEF_GRAPH(G, ...) ::EG_NS::Graph G = ::EG_NS::detail::GraphDefiner(#G, ##__VA_ARGS__) | [&](auto &&BUILDER)
#define DATA_CHAIN(...) ::EG_NS::ChainBuilder(BUILDER, ::EG_NS::EdgeType::DATA)->__VA_ARGS__ #define DATA_CHAIN(...) ::EG_NS::ChainBuilder(BUILDER, ::EG_NS::EdgeType::DATA)->__VA_ARGS__
#define CTRL_CHAIN(...) ::EG_NS::ChainBuilder(BUILDER, ::EG_NS::EdgeType::CTRL)->__VA_ARGS__ #define CTRL_CHAIN(...) ::EG_NS::ChainBuilder(BUILDER, ::EG_NS::EdgeType::CTRL)->__VA_ARGS__
#define CHAIN(...) DATA_CHAIN(__VA_ARGS__) #define CHAIN(...) DATA_CHAIN(__VA_ARGS__)


+ 6
- 2
tests/framework/easy_graph/src/layout/graph_layout.cc View File

@@ -16,10 +16,15 @@


#include "easy_graph/layout/graph_layout.h" #include "easy_graph/layout/graph_layout.h"
#include "easy_graph/layout/layout_executor.h" #include "easy_graph/layout/layout_executor.h"
#include "easy_graph/layout/engines/graph_easy/graph_easy_executor.h"
#include "easy_graph/graph/graph.h" #include "easy_graph/graph/graph.h"


EG_NS_BEGIN EG_NS_BEGIN


namespace {
// File-local executor instance; GraphLayout::Layout falls back to it when no executor has been configured.
GraphEasyExecutor default_executor;
}

void GraphLayout::Config(LayoutExecutor &executor, const LayoutOption *opts) { void GraphLayout::Config(LayoutExecutor &executor, const LayoutOption *opts) {
this->executor_ = &executor; this->executor_ = &executor;
options_ = opts; options_ = opts;
@@ -27,8 +32,7 @@ void GraphLayout::Config(LayoutExecutor &executor, const LayoutOption *opts) {


Status GraphLayout::Layout(const Graph &graph, const LayoutOption *opts) { Status GraphLayout::Layout(const Graph &graph, const LayoutOption *opts) {
const LayoutOption *options = opts ? opts : this->options_; const LayoutOption *options = opts ? opts : this->options_;
if (!executor_)
return EG_UNIMPLEMENTED;
if (!executor_) return static_cast<LayoutExecutor &>(default_executor).Layout(graph, options);
return executor_->Layout(graph, options); return executor_->Layout(graph, options);
} }




+ 37
- 0
tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/assert_error.h View File

@@ -0,0 +1,37 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef D52AA06185E34BBFB714FFBCDAB0D53A
#define D52AA06185E34BBFB714FFBCDAB0D53A

#include "ge_graph_dsl/ge.h"
#include <exception>
#include <string>

GE_NS_BEGIN

// Exception type whose message combines a source location with a caller-supplied description.
struct AssertError : std::exception {
// file/line identify where the failure was raised; info is the descriptive text.
AssertError(const char *file, int line, const std::string &info);

private:
// Overrides std::exception::what(); declared private in this class, so callers reach it through the base type.
const char *what() const noexcept override;

private:
// Stored message text.
std::string info;
};

GE_NS_END

#endif

+ 32
- 0
tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/check_utils.h View File

@@ -0,0 +1,32 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef INC_31309AA0A4E44C009C22AD9351BF3410
#define INC_31309AA0A4E44C009C22AD9351BF3410

#include "ge_graph_dsl/ge.h"
#include "graph/compute_graph.h"

GE_NS_BEGIN

// Callback type that receives a compute graph to inspect.
// NOTE(review): std::function and std::string are used here without <functional>/<string> being
// included directly in this header — presumably provided transitively; confirm.
using GraphCheckFun = std::function<void(const ::GE_NS::ComputeGraphPtr &)>;
// Static entry points of the graph-check facility.
struct CheckUtils {
// Takes a phase id and a check callback; returns a bool result.
static bool CheckGraph(const std::string &phase_id, const GraphCheckFun &fun);
// One-time setup hook for the facility.
static void init();
};

GE_NS_END

#endif

tests/framework/utils/builder/tensor_builder_utils.cc → tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/filter_scope_guard.h View File

@@ -1,17 +1,32 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensor_builder_utils.h"
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef C8B32320BD4943D588594B82FFBF2685
#define C8B32320BD4943D588594B82FFBF2685

#include <vector>
#include <string>
#include "ge_graph_dsl/ge.h"

GE_NS_BEGIN

// Scope guard constructed from a list of strings; it declares a user-provided destructor.
// NOTE(review): the type is implicitly copyable and the single-argument constructor is not
// `explicit` — confirm whether copies/implicit conversions are intended.
struct FilterScopeGuard {
FilterScopeGuard(const std::vector<std::string> &);
~FilterScopeGuard();
};

GE_NS_END

#endif

+ 59
- 0
tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/graph_assert.h View File

@@ -0,0 +1,59 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef AD954C4ADF5B44F5B1CC8BCD72EE9ED6
#define AD954C4ADF5B44F5B1CC8BCD72EE9ED6

#include "ge_graph_dsl/ge.h"
#include "ge_graph_dsl/assert/check_utils.h"
#include "ge_graph_dsl/assert/assert_error.h"
#include "ge_graph_dsl/assert/filter_scope_guard.h"

GE_NS_BEGIN

// Reports a failed graph check at (file, line): through gtest's GTEST_MESSAGE_AT_ as a fatal
// failure when that macro is available, otherwise by throwing AssertError.
#ifdef GTEST_MESSAGE_AT_
#define GRAPH_CHECK_MESSAGE(file, line, message) \
  GTEST_MESSAGE_AT_(file, line, message, ::testing::TestPartResult::kFatalFailure)
#else
#define GRAPH_CHECK_MESSAGE(file, line, message) throw AssertError(file, line, message)
#endif

namespace detail {
/// Remembers where a check was written and which phase it targets; `|` attaches the check
/// callback and runs it through CheckUtils::CheckGraph.
struct GraphAssert {
  GraphAssert(const char *file, unsigned int line, const std::string &phase_id)
      : file_(file), line_(line), phase_id_(phase_id) {}

  /// Runs `check_fun` for the stored phase; reports a failure at the stored location when
  /// CheckGraph returns false.
  void operator|(const ::GE_NS::GraphCheckFun &check_fun) {
    if (::GE_NS::CheckUtils::CheckGraph(phase_id_, check_fun)) {
      return;
    }
    auto message = "expect dump graph in phase: [" + phase_id_ + "], while not find the dump graph! ";
    GRAPH_CHECK_MESSAGE(file_, line_, message.c_str());
  }

 private:
  const char *file_;
  unsigned int line_;
  const std::string phase_id_;
};
}  // namespace detail

// Token-pasting helpers. The extra indirection lets __COUNTER__ expand to a number before it is
// pasted, so every use yields a distinct identifier.
#define GE_GRAPH_DSL_CONCAT_IMPL(lhs, rhs) lhs##rhs
#define GE_GRAPH_DSL_CONCAT(lhs, rhs) GE_GRAPH_DSL_CONCAT_IMPL(lhs, rhs)
#define GE_GRAPH_DSL_UNIQUE_NAME(prefix) GE_GRAPH_DSL_CONCAT(prefix, __COUNTER__)

// Declares a uniquely named FilterScopeGuard built from the given phase names. The trailing ';'
// is kept so call sites written without one keep compiling.
#define DUMP_GRAPH_WHEN(...) \
  ::GE_NS::FilterScopeGuard GE_GRAPH_DSL_UNIQUE_NAME(filter_scope_guard_)({__VA_ARGS__});
// Starts a graph check for `phase_id` (stringified); must be followed by the lambda body and ';'.
#define CHECK_GRAPH(phase_id) \
  ::GE_NS::detail::GraphAssert(__FILE__, __LINE__, #phase_id) | [&](const ::GE_NS::ComputeGraphPtr &graph)

GE_NS_END

#endif

+ 2
- 4
tests/framework/ge_graph_dsl/include/ge_graph_dsl/op_desc/op_desc_cfg.h View File

@@ -33,14 +33,12 @@ struct OpDescCfg {
std::vector<int64_t> shape_; std::vector<int64_t> shape_;
}; };


OpDescCfg(const OpType &type, int in_cnt = 0, int out_cnt = 0, Format format = FORMAT_NCHW,
OpDescCfg(const OpType &type, int in_cnt = 1, int out_cnt = 1, Format format = FORMAT_NCHW,
DataType data_type = DT_FLOAT, std::vector<int64_t> shape = {1, 1, 224, 224}) DataType data_type = DT_FLOAT, std::vector<int64_t> shape = {1, 1, 224, 224})
: type_(type), in_cnt_(in_cnt), out_cnt_(out_cnt), default_tensor_(format, data_type, shape) {} : type_(type), in_cnt_(in_cnt), out_cnt_(out_cnt), default_tensor_(format, data_type, shape) {}


protected: protected:
OpType GetType() const {
return type_;
}
OpType GetType() const { return type_; }
OpType type_; OpType type_;
int in_cnt_; int in_cnt_;
int out_cnt_; int out_cnt_;


+ 26
- 0
tests/framework/ge_graph_dsl/src/assert/assert_error.cc View File

@@ -0,0 +1,26 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ge_graph_dsl/assert/assert_error.h"

GE_NS_BEGIN

// Builds the message as "<file>:<line>\n<info>". Inside the initializer expression `info` names
// the constructor parameter, while the initialized entity is the member.
AssertError::AssertError(const char *file, int line, const std::string &info)
    : info(std::string(file) + ":" + std::to_string(line) + "\n" + info) {}

// Exposes the stored message as a C string.
const char *AssertError::what() const noexcept {
  return info.c_str();
}

GE_NS_END

+ 34
- 0
tests/framework/ge_graph_dsl/src/assert/check_utils.cc View File

@@ -0,0 +1,34 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ge_graph_dsl/assert/check_utils.h"
#include "graph/utils/dumper/ge_graph_dumper.h"
#include "ge_graph_default_checker.h"
#include "ge_graph_check_dumper.h"

GE_NS_BEGIN

bool CheckUtils::CheckGraph(const std::string &phase_id, const GraphCheckFun &fun) {
auto &dumper = dynamic_cast<GeGraphCheckDumper &>(GraphDumperRegistry::GetDumper());
return dumper.CheckFor(GeGraphDefaultChecker(phase_id, fun));
}

void CheckUtils::init() {
static GeGraphCheckDumper checkDumper;
GraphDumperRegistry::Register(checkDumper);
}

GE_NS_END

+ 31
- 0
tests/framework/ge_graph_dsl/src/assert/filter_scope_guard.cc View File

@@ -0,0 +1,31 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ge_graph_dsl/assert/filter_scope_guard.h"
#include "graph/utils/dumper/ge_graph_dumper.h"
#include "ge_dump_filter.h"

GE_NS_BEGIN

namespace {
// Views the registered dumper through its GeDumpFilter interface; the reference dynamic_cast
// throws std::bad_cast when the dumper does not implement it.
GeDumpFilter &GetDumpFilter() {
  auto &registered_dumper = GraphDumperRegistry::GetDumper();
  return dynamic_cast<GeDumpFilter &>(registered_dumper);
}
}  // namespace

// Installs `filter` on the dump filter for the lifetime of this guard.
FilterScopeGuard::FilterScopeGuard(const std::vector<std::string> &filter) {
  GetDumpFilter().Update(filter);
}

// Resets the dump filter when the guard goes out of scope.
FilterScopeGuard::~FilterScopeGuard() {
  GetDumpFilter().Reset();
}

GE_NS_END

+ 33
- 0
tests/framework/ge_graph_dsl/src/assert/ge_dump_filter.h View File

@@ -0,0 +1,33 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef INC_4C6224E8F7474EF89B18CCB0E4B19FD6
#define INC_4C6224E8F7474EF89B18CCB0E4B19FD6

#include <vector>
#include <string>
#include "ge_graph_dsl/ge.h"
#include "easy_graph/infra/keywords.h"

GE_NS_BEGIN

// Filter-control interface with two operations: Update (takes a list of strings) and Reset.
// NOTE(review): INTERFACE/ABSTRACT come from easy_graph/infra/keywords.h — presumably they
// declare a type with pure-virtual members; confirm there.
INTERFACE(GeDumpFilter) {
ABSTRACT(void Update(const std::vector<std::string> &));
ABSTRACT(void Reset());
};

GE_NS_END

#endif

+ 79
- 0
tests/framework/ge_graph_dsl/src/assert/ge_graph_check_dumper.cc View File

@@ -0,0 +1,79 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ge_graph_check_dumper.h"
#include "graph/model.h"
#include "graph/buffer.h"
#include "graph/utils/graph_utils.h"
#include "ge_graph_default_checker.h"

GE_NS_BEGIN

// Starts in the default state by running Reset().
GeGraphCheckDumper::GeGraphCheckDumper() {
  Reset();
}

// True when `suffix` is one of the currently accepted suffixes.
bool GeGraphCheckDumper::IsNeedDump(const std::string &suffix) const {
  return std::find(suffixes_.begin(), suffixes_.end(), suffix) != suffixes_.end();
}

// Serializes `graph` into the buffer kept for `suffix`, but only for accepted suffixes.
// A later dump with the same suffix overwrites the earlier one.
void GeGraphCheckDumper::Dump(const ge::ComputeGraphPtr &graph, const std::string &suffix) {
  if (IsNeedDump(suffix)) {
    // operator[] yields the existing buffer or inserts a default-constructed one;
    // DumpGraph clears the buffer before writing either way.
    DumpGraph(graph, buffers_[suffix]);
  }
}

// Runs `checker` against the buffer stored for its phase id.
// Returns false when nothing is stored for that phase, true once the check has been run.
bool GeGraphCheckDumper::CheckFor(const GeGraphChecker &checker) {
  auto found = buffers_.find(checker.PhaseId());
  if (found != buffers_.end()) {
    DoCheck(checker, found->second);
    return true;
  }
  return false;
}

void GeGraphCheckDumper::DoCheck(const GeGraphChecker &checker, ::GE_NS::Buffer &buffer) {
Model model("", "");
Model::Load(buffer.GetData(), buffer.GetSize(), model);
auto load_graph = model.GetGraph();
checker.Check(GraphUtils::GetComputeGraph(load_graph));
}

// Clears `buffer`, then saves a Model holding `graph` into it.
void GeGraphCheckDumper::DumpGraph(const ge::ComputeGraphPtr &graph, ::GE_NS::Buffer &buffer) {
  buffer.clear();
  Model snapshot("", "");
  snapshot.SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph));
  snapshot.Save(buffer, true);
}

// Replaces the accepted suffixes and discards every stored buffer.
void GeGraphCheckDumper::Update(const std::vector<std::string> &new_suffixes_) {
  buffers_.clear();
  suffixes_ = new_suffixes_;
}

// Restores the default accepted suffix list and discards every stored buffer.
void GeGraphCheckDumper::Reset() {
  static const std::vector<std::string> kDefaultSuffixes{"PreRunAfterBuild"};
  buffers_.clear();
  suffixes_ = kDefaultSuffixes;
}

GE_NS_END

+ 49
- 0
tests/framework/ge_graph_dsl/src/assert/ge_graph_check_dumper.h View File

@@ -0,0 +1,49 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef INC_8EFED0015C27464897BF64531355C810
#define INC_8EFED0015C27464897BF64531355C810

#include "ge_graph_dsl/ge.h"
#include "graph/utils/dumper/ge_graph_dumper.h"
#include "ge_dump_filter.h"
#include <string>

GE_NS_BEGIN

struct GeGraphChecker;

// Dumper that derives from both GeGraphDumper and GeDumpFilter. It keeps one Buffer per string
// key plus a list of suffix strings.
// NOTE(review): std::map and std::vector are used below, but <map>/<vector> are not included
// directly in this header — presumably provided transitively; confirm.
struct GeGraphCheckDumper : GeGraphDumper, GeDumpFilter {
GeGraphCheckDumper();
// NOTE(review): presumably overrides GeGraphDumper::Dump — confirm and consider marking it `override`.
virtual void Dump(const ge::ComputeGraphPtr &graph, const std::string &suffix);
// Takes a checker and returns a bool result.
bool CheckFor(const GeGraphChecker &checker);

private:
// Internal helpers operating on a single buffer.
void DoCheck(const GeGraphChecker &checker, ::GE_NS::Buffer &buffer);
void DumpGraph(const ge::ComputeGraphPtr &graph, ::GE_NS::Buffer &buffer);

private:
// GeDumpFilter overrides, private in this class and therefore called through the GeDumpFilter interface.
void Update(const std::vector<std::string> &) override;
void Reset() override;
bool IsNeedDump(const std::string &suffix) const;

private:
// Buffers keyed by string.
std::map<std::string, ::GE_NS::Buffer> buffers_;
// Suffix strings.
std::vector<std::string> suffixes_;
};

GE_NS_END

#endif

+ 32
- 0
tests/framework/ge_graph_dsl/src/assert/ge_graph_checker.h View File

@@ -0,0 +1,32 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef INC_5960A8F437324904BEE0690271258762
#define INC_5960A8F437324904BEE0690271258762

#include "ge_graph_dsl/ge.h"
#include "easy_graph/infra/keywords.h"
#include "graph/compute_graph.h"

GE_NS_BEGIN

// Checker interface: PhaseId() names the phase a checker targets and Check() receives a compute graph.
// NOTE(review): INTERFACE/ABSTRACT come from easy_graph/infra/keywords.h — presumably they
// declare a type with pure-virtual members; confirm there.
INTERFACE(GeGraphChecker) {
ABSTRACT(const std::string &PhaseId() const);
ABSTRACT(void Check(const ge::ComputeGraphPtr &graph) const);
};

GE_NS_END

#endif

+ 28
- 0
tests/framework/ge_graph_dsl/src/assert/ge_graph_default_checker.cc View File

@@ -0,0 +1,28 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ge_graph_default_checker.h"

GE_NS_BEGIN

// Stores copies of the phase id and the check callback.
GeGraphDefaultChecker::GeGraphDefaultChecker(const std::string &phase_id, const GraphCheckFun &check_fun)
    : phase_id_(phase_id), check_fun_(check_fun) {}

// Phase id given at construction.
const std::string &GeGraphDefaultChecker::PhaseId() const {
  return phase_id_;
}

// Forwards `graph` to the stored callback.
void GeGraphDefaultChecker::Check(const ge::ComputeGraphPtr &graph) const {
  check_fun_(graph);
}

GE_NS_END

+ 41
- 0
tests/framework/ge_graph_dsl/src/assert/ge_graph_default_checker.h View File

@@ -0,0 +1,41 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef BCF4D96BE9FC48938DE7B7E93B551C54
#define BCF4D96BE9FC48938DE7B7E93B551C54

#include "ge_graph_dsl/ge.h"
#include "ge_graph_checker.h"
#include "graph/compute_graph.h"

GE_NS_BEGIN

// Callback type that receives a compute graph to inspect.
// NOTE(review): std::function is used without <functional> being included directly in this
// header — presumably provided transitively; confirm.
using GraphCheckFun = std::function<void(const ::GE_NS::ComputeGraphPtr &)>;

// GeGraphChecker implementation holding a phase id string and a check callback.
struct GeGraphDefaultChecker : GeGraphChecker {
// Arguments, in order: a string and the check callback.
GeGraphDefaultChecker(const std::string &, const GraphCheckFun &);

private:
// Interface overrides; private here, so they are invoked through a GeGraphChecker reference.
const std::string &PhaseId() const override;
void Check(const ge::ComputeGraphPtr &graph) const override;

private:
const std::string phase_id_;
const GraphCheckFun check_fun_;
};

GE_NS_END

#endif

tests/framework/ge_graph_dsl/src/op_desc_cfg_box.cc → tests/framework/ge_graph_dsl/src/op_desc/op_desc_cfg_box.cc View File


tests/framework/ge_graph_dsl/src/op_desc_cfg_repo.cc → tests/framework/ge_graph_dsl/src/op_desc/op_desc_cfg_repo.cc View File

@@ -23,15 +23,22 @@ GE_NS_BEGIN


namespace { namespace {


#define OP_CFG(optype, ...) \
{ \
optype, OpDescCfg { \
optype, __VA_ARGS__ \
} \
#define OP_CFG(optype, ...) \
{ \
optype, OpDescCfg { optype, __VA_ARGS__ } \
} }


static std::map<OpType, OpDescCfg> cfg_repo{OP_CFG(DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), static std::map<OpType, OpDescCfg> cfg_repo{OP_CFG(DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}),
OP_CFG(ADD, 2, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), OP_CFG(ADD, 2, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}),
OP_CFG(ENTER, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}),
OP_CFG(MERGE, 2, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}),
OP_CFG(CONSTANT, 0, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}),
OP_CFG(LESS, 2, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}),
OP_CFG(LOOPCOND, 1, 1, FORMAT_NCHW, DT_BOOL, {1, 1, 224, 224}),
OP_CFG(SWITCH, 2, 2, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}),
OP_CFG(EXIT, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}),
OP_CFG(NEXTITERATION, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}),
OP_CFG(NETOUTPUT, 2, 2, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}),
OP_CFG(VARIABLE, 1, 1)}; OP_CFG(VARIABLE, 1, 1)};
} // namespace } // namespace



tests/framework/ge_graph_dsl/src/op_desc_ptr_box.cc → tests/framework/ge_graph_dsl/src/op_desc/op_desc_ptr_box.cc View File

@@ -19,6 +19,4 @@


USING_GE_NS USING_GE_NS


OpDescPtr OpDescPtrBox::Build(const ::EG_NS::NodeId &id) const {
return op_;
}
OpDescPtr OpDescPtrBox::Build(const ::EG_NS::NodeId &id) const { return op_; }

tests/framework/ge_graph_dsl/src/ge_graph_vistor.cc → tests/framework/ge_graph_dsl/src/vistor/ge_graph_visitor.cc View File

@@ -36,17 +36,11 @@ GE_NS_BEGIN


GeGraphVisitor::GeGraphVisitor() : build_graph_(std::make_shared<ComputeGraph>("")) {} GeGraphVisitor::GeGraphVisitor() : build_graph_(std::make_shared<ComputeGraph>("")) {}


void GeGraphVisitor::reset(const ComputeGraphPtr &graph) {
build_graph_ = graph;
}
void GeGraphVisitor::reset(const ComputeGraphPtr &graph) { build_graph_ = graph; }


Graph GeGraphVisitor::BuildGeGraph() const {
return GraphUtils::CreateGraphFromComputeGraph(build_graph_);
}
Graph GeGraphVisitor::BuildGeGraph() const { return GraphUtils::CreateGraphFromComputeGraph(build_graph_); }


ComputeGraphPtr GeGraphVisitor::BuildComputeGraph() const {
return build_graph_;
}
ComputeGraphPtr GeGraphVisitor::BuildComputeGraph() const { return build_graph_; }


Status GeGraphVisitor::Visit(const ::EG_NS::Graph &graph) { Status GeGraphVisitor::Visit(const ::EG_NS::Graph &graph) {
build_graph_->SetName(graph.GetName()); build_graph_->SetName(graph.GetName());

tests/framework/ge_graph_dsl/src/ge_subgraph_vistor.cc → tests/framework/ge_graph_dsl/src/vistor/ge_subgraph_vistor.cc View File


tests/framework/ge_graph_dsl/src/graph_dsl.cc → tests/framework/ge_graph_dsl/src/vistor/graph_dsl.cc View File


+ 1
- 1
tests/framework/ge_graph_dsl/tests/CMakeLists.txt View File

@@ -26,7 +26,7 @@ target_compile_options(ge_graph_dsl_test PRIVATE
) )
set_target_properties(ge_graph_dsl_test PROPERTIES CXX_STANDARD 17) set_target_properties(ge_graph_dsl_test PROPERTIES CXX_STANDARD 17)


target_link_libraries(ge_graph_dsl_test PUBLIC gtest gtest_main ge_graph_dsl)
target_link_libraries(ge_graph_dsl_test PUBLIC gtest ge_graph_dsl)


include(CTest) include(CTest)
enable_testing() enable_testing()

+ 129
- 0
tests/framework/ge_graph_dsl/tests/check_graph_test.cc View File

@@ -0,0 +1,129 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "gtest/gtest.h"
#include "easy_graph/layout/graph_layout.h"
#include "easy_graph/layout/engines/graph_easy/graph_easy_option.h"
#include "easy_graph/layout/engines/graph_easy/graph_easy_executor.h"
#include "ge_graph_dsl/graph_dsl.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/dumper/ge_graph_dumper.h"
#include "framework/common/types.h"
#include "ge_graph_dsl/assert/graph_assert.h"
#include "graph/model.h"
#include "graph/buffer.h"

USING_GE_NS

class CheckGraphTest : public testing::Test {
private:
EG_NS::GraphEasyExecutor executor;

protected:
void SetUp() { EG_NS::GraphLayout::GetInstance().Config(executor, nullptr); }
void TearDown() {}
};

// A graph dumped under a phase named in DUMP_GRAPH_WHEN can afterwards be inspected
// with CHECK_GRAPH for that same phase.
TEST_F(CheckGraphTest, test_ge_graph_dump_is_work) {
// g1: one control chain data1 -> add, i.e. two nodes.
DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); };

DUMP_GRAPH_WHEN("after_build");
GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "after_build");

// `graph` is provided by the CHECK_GRAPH macro
// (presumably the graph recorded for the phase - confirm in graph_assert.h).
CHECK_GRAPH(after_build) {
ASSERT_EQ(graph->GetName(), "g1");
ASSERT_EQ(graph->GetAllNodesSize(), 2);
};
}

// Two phases are named in one DUMP_GRAPH_WHEN call; each phase must report the graph
// that was dumped under its own name.
TEST_F(CheckGraphTest, test_ge_graph_dump_two_phase) {
// g1: data1 -> add (two nodes).
DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); };
// g2: data1 -> add and data2 -> add; "add" is named in both chains, the test expects three nodes.
DEF_GRAPH(g2) {
CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD));
CTRL_CHAIN(NODE("data2", DATA)->NODE("add", ADD));
};

DUMP_GRAPH_WHEN("before_build", "after_build");

GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "before_build");
GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g2), "after_build");

// before_build was dumped with g1.
CHECK_GRAPH(before_build) {
ASSERT_EQ(graph->GetName(), "g1");
ASSERT_EQ(graph->GetAllNodesSize(), 2);
};

// after_build was dumped with g2.
CHECK_GRAPH(after_build) {
ASSERT_EQ(graph->GetName(), "g2");
ASSERT_EQ(graph->GetAllNodesSize(), 3);
};
}

// Dumping twice under the same phase: the check is expected to see the graph from
// the most recent dump (g2), not the first one (g1).
TEST_F(CheckGraphTest, test_ge_graph_dump_one_phase_two_times) {
  // g1: data1 -> add (two nodes).
  DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); };
  // g2: data1 -> add and data2 -> add; the test expects three nodes.
  DEF_GRAPH(g2) {
    CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD));
    CTRL_CHAIN(NODE("data2", DATA)->NODE("add", ADD));
  };

  // Terminated with ';' like every other DUMP_GRAPH_WHEN use in this file, so the
  // statement does not depend on how the macro happens to expand.
  DUMP_GRAPH_WHEN("before_build");

  GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "before_build");
  GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g2), "before_build");

  CHECK_GRAPH(before_build) {
    ASSERT_EQ(graph->GetName(), "g2");
    ASSERT_EQ(graph->GetAllNodesSize(), 3);
  };
}

// Checking a phase that was not named in DUMP_GRAPH_WHEN must fail: only "before_build"
// is requested, the graph is dumped under "after_build", and CheckGraph("after_build")
// is expected to return false.
TEST_F(CheckGraphTest, test_check_phases_is_work) {
DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); };

DUMP_GRAPH_WHEN("before_build");
GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "after_build");
// CheckUtils::CheckGraph is called directly (instead of via CHECK_GRAPH) so its return value
// can be asserted; the callback is intentionally empty.
auto ret = ::GE_NS::CheckUtils::CheckGraph("after_build", [&](const ::GE_NS::ComputeGraphPtr &graph) {});
ASSERT_FALSE(ret);
}

// Only "before_build" is named in DUMP_GRAPH_WHEN, but the graph is dumped under both
// "before_build" and "after_build"; the extra dump must not affect the check of "before_build".
TEST_F(CheckGraphTest, test_check_one_phase_dump_another_not_dump) {
// g1: data1 -> add (two nodes).
DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); };

DUMP_GRAPH_WHEN("before_build");
GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "before_build");
GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "after_build");

CHECK_GRAPH(before_build) {
ASSERT_EQ(graph->GetName(), "g1");
ASSERT_EQ(graph->GetAllNodesSize(), 2);
};
}

// Round trip: save a Model holding g1 into a Buffer, load it back into a second Model,
// and verify the reloaded graph keeps its name and node count.
TEST_F(CheckGraphTest, test_model_serialize_and_unserialize_success) {
// g1: data1 -> add (two nodes).
DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); };
auto ge_graph = ToGeGraph(g1);

ge::Model model("", "");
model.SetGraph(ge_graph);
Buffer buffer;
// NOTE(review): the results of Save() and Load() are not checked here; a failure would only
// show up indirectly through the assertions on load_graph below - consider asserting them.
model.Save(buffer, true);

ge::Model loadModel("", "");
Model::Load(buffer.GetData(), buffer.GetSize(), loadModel);
auto load_graph = loadModel.GetGraph();

ASSERT_EQ(load_graph.GetName(), "g1");
ASSERT_EQ(load_graph.GetAllNodes().size(), 2);
}

+ 16
- 28
tests/framework/ge_graph_dsl/tests/graph_dsl_test.cc View File

@@ -37,17 +37,13 @@ class GraphDslTest : public testing::Test {
EG_NS::GraphEasyExecutor executor; EG_NS::GraphEasyExecutor executor;


protected: protected:
void SetUp() {
EG_NS::GraphLayout::GetInstance().Config(executor, nullptr);
}
void SetUp() { EG_NS::GraphLayout::GetInstance().Config(executor, nullptr); }


void TearDown() {} void TearDown() {}
}; };


TEST_F(GraphDslTest, test_build_graph_from_optype_with_name) { TEST_F(GraphDslTest, test_build_graph_from_optype_with_name) {
DEF_GRAPH(g1) {
CHAIN(NODE("data1", DATA)->NODE("add", ADD));
});
DEF_GRAPH(g1) { CHAIN(NODE("data1", DATA)->NODE("add", ADD)); };


auto geGraph = ToGeGraph(g1); auto geGraph = ToGeGraph(g1);
auto computeGraph = ToComputeGraph(g1); auto computeGraph = ToComputeGraph(g1);
@@ -57,9 +53,7 @@ TEST_F(GraphDslTest, test_build_graph_from_optype_with_name) {
} }


TEST_F(GraphDslTest, test_build_graph_with_name) { TEST_F(GraphDslTest, test_build_graph_with_name) {
DEF_GRAPH(g1, "sample_graph") {
CHAIN(NODE("data1", DATA)->NODE("add", ADD));
});
DEF_GRAPH(g1, "sample_graph") { CHAIN(NODE("data1", DATA)->NODE("add", ADD)); };


auto geGraph = ToGeGraph(g1); auto geGraph = ToGeGraph(g1);


@@ -72,7 +66,7 @@ TEST_F(GraphDslTest, test_build_from_from_op_desc_ptr) {
auto data = std::make_shared<OpDesc>("data1", DATA); auto data = std::make_shared<OpDesc>("data1", DATA);
auto add = std::make_shared<OpDesc>("Add", ADD); auto add = std::make_shared<OpDesc>("Add", ADD);
CHAIN(NODE(data)->NODE(add)); CHAIN(NODE(data)->NODE(add));
});
};


auto geGraph = ToGeGraph(g1); auto geGraph = ToGeGraph(g1);


@@ -84,7 +78,7 @@ TEST_F(GraphDslTest, test_build_from_op_desc_cfg) {
auto datCfg = OP_CFG(DATA).InCnt(1).OutCnt(1); auto datCfg = OP_CFG(DATA).InCnt(1).OutCnt(1);
auto addCfg = OP_CFG(DATA).InCnt(1).OutCnt(1); auto addCfg = OP_CFG(DATA).InCnt(1).OutCnt(1);
CHAIN(NODE("data1", datCfg)->NODE("add", addCfg)); CHAIN(NODE("data1", datCfg)->NODE("add", addCfg));
});
};


auto geGraph = ToGeGraph(g1); auto geGraph = ToGeGraph(g1);


@@ -92,9 +86,7 @@ TEST_F(GraphDslTest, test_build_from_op_desc_cfg) {
} }


TEST_F(GraphDslTest, test_build_from_op_desc_cfg_inline) { TEST_F(GraphDslTest, test_build_from_op_desc_cfg_inline) {
DEF_GRAPH(g1) {
CHAIN(NODE("data1", OP_CFG(DATA).InCnt(1).OutCnt(1))->NODE("add", OP_CFG(ADD).InCnt(2).OutCnt(1)));
});
DEF_GRAPH(g1) { CHAIN(NODE("data1", OP_CFG(DATA).InCnt(1).OutCnt(1))->NODE("add", OP_CFG(ADD).InCnt(2).OutCnt(1))); };


auto geGraph = ToGeGraph(g1); auto geGraph = ToGeGraph(g1);


@@ -102,9 +94,7 @@ TEST_F(GraphDslTest, test_build_from_op_desc_cfg_inline) {
} }


TEST_F(GraphDslTest, test_build_from_control_chain) { TEST_F(GraphDslTest, test_build_from_control_chain) {
DEF_GRAPH(g1) {
CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD));
});
DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); };


auto geGraph = ToGeGraph(g1); auto geGraph = ToGeGraph(g1);


@@ -112,9 +102,7 @@ TEST_F(GraphDslTest, test_build_from_control_chain) {
} }


TEST_F(GraphDslTest, test_build_from_data_chain) { TEST_F(GraphDslTest, test_build_from_data_chain) {
DEF_GRAPH(g1) {
DATA_CHAIN(NODE("data1", DATA)->NODE("add", ADD));
});
DEF_GRAPH(g1) { DATA_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); };


auto geGraph = ToGeGraph(g1); auto geGraph = ToGeGraph(g1);


@@ -125,7 +113,7 @@ TEST_F(GraphDslTest, test_build_from_data_chain_with_edge) {
DEF_GRAPH(g1) { DEF_GRAPH(g1) {
CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD));
CHAIN(NODE("data1", DATA)->EDGE(2, 2)->NODE("add")); CHAIN(NODE("data1", DATA)->EDGE(2, 2)->NODE("add"));
});
};


auto geGraph = ToGeGraph(g1); auto geGraph = ToGeGraph(g1);


@@ -136,7 +124,7 @@ TEST_F(GraphDslTest, test_build_graph_reused_before_node) {
DEF_GRAPH(g1) { DEF_GRAPH(g1) {
CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD));
CHAIN(NODE("data1")->EDGE(2, 2)->NODE("add")); CHAIN(NODE("data1")->EDGE(2, 2)->NODE("add"));
});
};


auto geGraph = ToGeGraph(g1); auto geGraph = ToGeGraph(g1);


@@ -147,7 +135,7 @@ TEST_F(GraphDslTest, test_build_graph_with_constant_folding) {
DEF_GRAPH(g1) { DEF_GRAPH(g1) {
CHAIN(NODE("data1", DATA)->NODE("add", ADD)); CHAIN(NODE("data1", DATA)->NODE("add", ADD));
CHAIN(NODE("data2", DATA)->NODE("add")); CHAIN(NODE("data2", DATA)->NODE("add"));
});
};


auto geGraph = ToGeGraph(g1); auto geGraph = ToGeGraph(g1);


@@ -168,7 +156,7 @@ TEST_F(GraphDslTest, test_build_complex_normal_graph_build_suggested) {
->NODE("Add4") ->NODE("Add4")
->NODE("Add5") ->NODE("Add5")
->NODE("net_output", NETOUTPUT)); ->NODE("net_output", NETOUTPUT));
});
};


auto geGraph = ToGeGraph(g1); auto geGraph = ToGeGraph(g1);


@@ -187,7 +175,7 @@ TEST_F(GraphDslTest, test_build_complex_mult_normal_graph_build) {
CHAIN(NODE("add2")->NODE("net_output")); CHAIN(NODE("add2")->NODE("net_output"));
CHAIN(NODE("add3")->NODE("net_output")); CHAIN(NODE("add3")->NODE("net_output"));
CTRL_CHAIN(NODE("add1")->NODE("add2")->NODE("add3")); CTRL_CHAIN(NODE("add1")->NODE("add2")->NODE("add3"));
});
};


auto geGraph = ToGeGraph(g1); auto geGraph = ToGeGraph(g1);


@@ -198,17 +186,17 @@ TEST_F(GraphDslTest, test_build_graph_with_sub_graph) {
DEF_GRAPH(sub_1) { DEF_GRAPH(sub_1) {
CHAIN(NODE("data_i", DATA)->NODE("less", LESS)->NODE("netoutput", NETOUTPUT)); CHAIN(NODE("data_i", DATA)->NODE("less", LESS)->NODE("netoutput", NETOUTPUT));
CHAIN(NODE("const_5", CONSTANTOP)->NODE("less")); CHAIN(NODE("const_5", CONSTANTOP)->NODE("less"));
});
};


DEF_GRAPH(sub_2) { DEF_GRAPH(sub_2) {
CHAIN(NODE("data_a", DATA)->NODE("mul", MUL)->NODE("netoutput", NETOUTPUT)); CHAIN(NODE("data_a", DATA)->NODE("mul", MUL)->NODE("netoutput", NETOUTPUT));
CHAIN(NODE("const_2", CONSTANTOP)->NODE("mul")); CHAIN(NODE("const_2", CONSTANTOP)->NODE("mul"));
});
};


DEF_GRAPH(g1) { DEF_GRAPH(g1) {
CHAIN(NODE("data_a", DATA)->NODE("while", WHILE, sub_1, sub_2)->NODE("netoutput", NETOUTPUT)); CHAIN(NODE("data_a", DATA)->NODE("while", WHILE, sub_1, sub_2)->NODE("netoutput", NETOUTPUT));
CHAIN(NODE("data_i", DATA)->NODE("while")); CHAIN(NODE("data_i", DATA)->NODE("while"));
});
};


sub_1.Layout(); sub_1.Layout();
sub_2.Layout(); sub_2.Layout();


+ 6
- 0
tests/framework/ge_graph_dsl/tests/stub/optype_stub.cc View File

@@ -30,5 +30,11 @@ REGISTER_OPTYPE_DEFINE(MUL, "Mul");
REGISTER_OPTYPE_DEFINE(NETOUTPUT, "NetOutput"); REGISTER_OPTYPE_DEFINE(NETOUTPUT, "NetOutput");
REGISTER_OPTYPE_DEFINE(ADD, "Add"); REGISTER_OPTYPE_DEFINE(ADD, "Add");
REGISTER_OPTYPE_DEFINE(WHILE, "While"); REGISTER_OPTYPE_DEFINE(WHILE, "While");
REGISTER_OPTYPE_DEFINE(ENTER, "Enter");
REGISTER_OPTYPE_DEFINE(MERGE, "Merge");
REGISTER_OPTYPE_DEFINE(LOOPCOND, "Loopcond");
REGISTER_OPTYPE_DEFINE(SWITCH, "Switch");
REGISTER_OPTYPE_DEFINE(EXIT, "Exit");
REGISTER_OPTYPE_DEFINE(NEXTITERATION, "Nextiteration");


GE_NS_END GE_NS_END

tests/framework/utils/builder/tensor_builder_utils.h → tests/framework/ge_graph_dsl/tests/test_main.cc View File

@@ -1,22 +1,25 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H
#define GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H
class tensor_builder_utils {};
#endif // GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include "ge_graph_dsl/assert/check_utils.h"

// Test binary entry point: CheckUtils::init() runs first, then GoogleTest parses its
// command-line flags and executes every registered test.
int main(int argc, char **argv) {
  ::GE_NS::CheckUtils::init();
  testing::InitGoogleTest(&argc, argv);
  // The process exit code is whatever RUN_ALL_TESTS() reports.
  return RUN_ALL_TESTS();
}

+ 0
- 48
tests/framework/utils/builder/graph_builder_utils.cc View File

@@ -1,48 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph_builder_utils.h"
#include "inc/external/graph/operator.h"
#include "inc/external/graph/operator_factory.h"
#include "graph/utils/graph_utils.h"

namespace ge {
namespace st {
NodePtr ComputeGraphBuilder::AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt,
Format format, DataType data_type, std::vector<int64_t> shape) {
auto tensor_desc = std::make_shared<GeTensorDesc>();
tensor_desc->SetShape(GeShape(std::move(shape)));
tensor_desc->SetFormat(format);
tensor_desc->SetDataType(data_type);

auto op_desc = std::make_shared<OpDesc>(name, type);
for (int i = 0; i < in_cnt; ++i) {
op_desc->AddInputDesc(tensor_desc->Clone());
}
for (int i = 0; i < out_cnt; ++i) {
op_desc->AddOutputDesc(tensor_desc->Clone());
}

return graph_->AddNode(op_desc);
}
void ComputeGraphBuilder::AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx) {
GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_idx), dst_node->GetInDataAnchor(dst_idx));
}
void ComputeGraphBuilder::AddControlEdge(NodePtr &src_node, NodePtr &dst_node) {
GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor());
}
} // namespace st
} // namespace ge

+ 0
- 55
tests/framework/utils/builder/graph_builder_utils.h View File

@@ -1,55 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H
#define GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H

#include <string>
#include <vector>

#include "graph/compute_graph.h"
#include "graph/utils/graph_utils.h"
#include "graph/graph.h"
#include "graph/node.h"

namespace ge {
namespace st {
class ComputeGraphBuilder {
public:
explicit ComputeGraphBuilder(const std::string &name) {
graph_ = std::make_shared<ComputeGraph>(name);
}
NodePtr AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt,
Format format = FORMAT_NCHW, DataType data_type = DT_FLOAT,
std::vector<int64_t> shape = {1, 1, 224, 224});
void AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx);
void AddControlEdge(NodePtr &src_node, NodePtr &dst_node);
ComputeGraphPtr GetComputeGraph() {
graph_->TopologicalSorting();
return graph_;
}
Graph GetGraph() {
graph_->TopologicalSorting();
return GraphUtils::CreateGraphFromComputeGraph(graph_);
}

private:
ComputeGraphPtr graph_;
};
} // namespace st
} // namespace ge

#endif // GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H

+ 1
- 1
tests/st/testcase/CMakeLists.txt View File

@@ -8,7 +8,7 @@ target_include_directories(graph_engine_test


set_target_properties(graph_engine_test PROPERTIES CXX_STANDARD 17) set_target_properties(graph_engine_test PROPERTIES CXX_STANDARD 17)


target_link_libraries(graph_engine_test PRIVATE gtest framework)
target_link_libraries(graph_engine_test PRIVATE gtest ge_graph_dsl ge_with_env)


include(CTest) include(CTest)
enable_testing() enable_testing()

+ 49
- 78
tests/st/testcase/test_framework_dummy.cc View File

@@ -15,23 +15,12 @@
*/ */


#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <map>
#include "external/ge/ge_api.h" #include "external/ge/ge_api.h"
#include "ge_running_env/fake_engine.h"
#include "graph/debug/ge_attr_define.h" #include "graph/debug/ge_attr_define.h"
#include "framework/common/types.h" #include "framework/common/types.h"

#include "builder/graph_builder_utils.h"
#include "ge_running_env/ge_running_env_faker.h" #include "ge_running_env/ge_running_env_faker.h"

#include "graph/operator_reg.h"
#include "graph/operator.h"
#define protected public
#define private public
#include "graph/utils/op_desc_utils.h"
#include "ge_graph_dsl/graph_dsl.h" #include "ge_graph_dsl/graph_dsl.h"
#undef protected
#undef private
#include "ge_graph_dsl/assert/graph_assert.h"


using namespace std; using namespace std;
using namespace ge; using namespace ge;
@@ -57,76 +46,58 @@ namespace {
* *
**/ **/
Graph BuildV1ControlFlowGraph() { Graph BuildV1ControlFlowGraph() {
// build graph
st::ComputeGraphBuilder graphBuilder("g1");
auto data_i = graphBuilder.AddNode("data_i", DATA, 1, 1);
auto enter_i = graphBuilder.AddNode("enter_i", ENTER, 1, 1);
ge::AttrUtils::SetStr(enter_i->GetOpDesc(), ENTER_ATTR_FRAME_NAME, "1");
auto merge_i = graphBuilder.AddNode("merge_i", MERGE, 2, 1);
auto const_5 = graphBuilder.AddNode("const_5", CONSTANT, 0, 1);
auto less = graphBuilder.AddNode("less", LESS, 2, 1);
auto loopcond = graphBuilder.AddNode("loopcond", LOOPCOND, 1, 1, FORMAT_NCHW, DT_BOOL);
auto switch_i = graphBuilder.AddNode("switch_i", SWITCH, 2, 2);
auto exit_i = graphBuilder.AddNode("switch_i", EXIT, 1, 1);
auto const_1 = graphBuilder.AddNode("const_1", CONSTANT, 0, 1);
auto add = graphBuilder.AddNode("add", ADD, 2, 1);
auto next_iteration_i = graphBuilder.AddNode("next_iteration_i", NEXTITERATION, 1, 1);

auto data_a = graphBuilder.AddNode("data_a", DATA, 1, 1);
auto enter_a = graphBuilder.AddNode("enter_a", ENTER, 1, 1);
ge::AttrUtils::SetStr(enter_a->GetOpDesc(), ENTER_ATTR_FRAME_NAME, "1");
auto merge_a = graphBuilder.AddNode("merge_a", MERGE, 2, 1);
auto switch_a = graphBuilder.AddNode("switch_a", SWITCH, 2, 2);
auto exit_a = graphBuilder.AddNode("exit_a", EXIT, 1, 1);
auto mul = graphBuilder.AddNode("mul", MUL, 2, 1);
auto const_2 = graphBuilder.AddNode("const_2", CONSTANT, 0, 1);
auto next_iteration_a = graphBuilder.AddNode("next_iteration_a", NEXTITERATION, 1, 1);
auto netoutput = graphBuilder.AddNode("netoutput", NETOUTPUT, 2, 2);
// i = i+1
graphBuilder.AddDataEdge(data_i, 0, enter_i, 0);
graphBuilder.AddDataEdge(enter_i, 0, merge_i, 0);
graphBuilder.AddDataEdge(next_iteration_i, 0, merge_i, 1);
graphBuilder.AddDataEdge(merge_i, 0, less, 0);
graphBuilder.AddDataEdge(const_5, 0, less, 1);
graphBuilder.AddDataEdge(less, 0, loopcond, 0);
graphBuilder.AddDataEdge(loopcond, 0, switch_i, 1);
graphBuilder.AddDataEdge(merge_i, 0, switch_i, 0);
graphBuilder.AddDataEdge(switch_i, 0, exit_i, 0);
graphBuilder.AddDataEdge(switch_i, 1, add, 0);
graphBuilder.AddDataEdge(const_1, 0, add, 1);
graphBuilder.AddDataEdge(add, 0, next_iteration_i, 0);
graphBuilder.AddDataEdge(exit_i, 0, netoutput, 1);
// a=a*2
graphBuilder.AddDataEdge(data_a, 0, enter_a, 0);
graphBuilder.AddDataEdge(enter_a, 0, merge_a, 0);
graphBuilder.AddDataEdge(next_iteration_a, 0, merge_a, 1);
graphBuilder.AddDataEdge(loopcond, 0, switch_a, 1);
graphBuilder.AddDataEdge(merge_a, 0, switch_a, 0);
graphBuilder.AddDataEdge(switch_a, 0, exit_a, 0);
graphBuilder.AddDataEdge(switch_a, 1, mul, 0);
graphBuilder.AddDataEdge(const_2, 0, mul, 1);
graphBuilder.AddDataEdge(mul, 0, next_iteration_a, 0);
graphBuilder.AddDataEdge(exit_a, 0, netoutput, 0);
// set const weight
int64_t dims_size = 1; int64_t dims_size = 1;
vector<int64_t> data_vec = {5}; vector<int64_t> data_vec = {5};
for_each(data_vec.begin(), data_vec.end(), [&](int64_t &data) { dims_size *= data; }); for_each(data_vec.begin(), data_vec.end(), [&](int64_t &data) { dims_size *= data; });
vector<int32_t> data_value_vec(dims_size, 1); vector<int32_t> data_value_vec(dims_size, 1);
GeTensorDesc data_tensor_desc(GeShape(data_vec), FORMAT_NCHW, DT_INT32); GeTensorDesc data_tensor_desc(GeShape(data_vec), FORMAT_NCHW, DT_INT32);
GeTensorPtr data_tensor =
make_shared<GeTensor>(data_tensor_desc, (uint8_t *)data_value_vec.data(), data_value_vec.size() * sizeof(int32_t));
OpDescUtils::SetWeights(const_5->GetOpDesc(), data_tensor);
OpDescUtils::SetWeights(const_2->GetOpDesc(), data_tensor);
OpDescUtils::SetWeights(const_1->GetOpDesc(), data_tensor);
GeTensorPtr data_tensor = make_shared<GeTensor>(data_tensor_desc, (uint8_t *)data_value_vec.data(),
data_value_vec.size() * sizeof(int32_t));


return graphBuilder.GetGraph();
auto enter = OP_CFG(ENTER).Attr(ENTER_ATTR_FRAME_NAME, "1");
auto const_op = OP_CFG(CONSTANT).Weight(data_tensor);

DEF_GRAPH(g1) {
CHAIN(NODE("data_i", DATA)
->NODE("enter_i", enter)
->EDGE(0, 0)
->NODE("merge_i", MERGE)
->NODE("less", LESS)
->NODE("loopcond", LOOPCOND));
CHAIN(NODE("const_1", const_op)
->EDGE(0, 1)
->NODE("add", ADD)
->NODE("iteration_i", NEXTITERATION)
->EDGE(0, 1)
->NODE("merge_i"));
CHAIN(NODE("const_5", const_op)->EDGE(0, 1)->NODE("less"));
CHAIN(NODE("loopcond")
->EDGE(0, 1)
->NODE("switch_i", SWITCH)
->EDGE(0, 0)
->NODE("exit_i", EXIT)
->EDGE(0, 1)
->NODE("netoutput", NETOUTPUT));
CHAIN(NODE("merge_i")->EDGE(0, 0)->NODE("switch_i")->EDGE(1, 0)->NODE("add"));
CHAIN(NODE("data_a", DATA)
->NODE("enter_a", enter)
->NODE("merge_a", MERGE)
->NODE("switch_a", SWITCH)
->NODE("exit_a", EXIT)
->EDGE(0, 0)
->NODE("netoutput"));
CHAIN(NODE("iteration_a", NEXTITERATION)->EDGE(0, 1)->NODE("merge_a"));
CHAIN(NODE("loopcond")->EDGE(0, 1)->NODE("switch_a")->EDGE(1, 0)->NODE("mul", MUL));
CHAIN(NODE("const_2", const_op)->EDGE(0, 1)->NODE("mul")->EDGE(0, 0)->NODE("iteration_a"));
};
return ToGeGraph(g1);
} }
} // namespace } // namespace
class FrameworkTest : public testing::Test { class FrameworkTest : public testing::Test {
protected: protected:
GeRunningEnvFaker ge_env;
void SetUp() { ge_env.InstallDefault(); } void SetUp() { ge_env.InstallDefault(); }
void TearDown() {} void TearDown() {}
GeRunningEnvFaker ge_env;
}; };


/// data data /// data data
@@ -136,19 +107,19 @@ TEST_F(FrameworkTest, test_framework_add) {
DEF_GRAPH(g1) { DEF_GRAPH(g1) {
CHAIN(NODE("data1", DATA)->NODE("add", ADD)); CHAIN(NODE("data1", DATA)->NODE("add", ADD));
CHAIN(NODE("data2", DATA)->NODE("add")); CHAIN(NODE("data2", DATA)->NODE("add"));
});
};


auto graph = ToGeGraph(g1);
// new session & add graph
map<AscendString, AscendString> options; map<AscendString, AscendString> options;
Session session(options); Session session(options);
auto ret = session.AddGraph(1, graph, options);
EXPECT_EQ(ret, SUCCESS);
// build input tensor
session.AddGraph(1, ToGeGraph(g1), options);
std::vector<InputTensorInfo> inputs; std::vector<InputTensorInfo> inputs;
// build_graph through session
ret = session.BuildGraph(1, inputs);
auto ret = session.BuildGraph(1, inputs);
EXPECT_EQ(ret, SUCCESS); EXPECT_EQ(ret, SUCCESS);
CHECK_GRAPH(PreRunAfterBuild) {
ASSERT_EQ(graph->GetName(), "g1_1");
ASSERT_EQ(graph->GetAllNodesSize(), 4);
};
} }


/** data a = 2; /** data a = 2;


+ 4
- 16
tests/st/testcase/test_ge_opt_info.cc View File

@@ -15,24 +15,12 @@
*/ */


#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "easy_graph/graph/box.h"
#include "easy_graph/graph/node.h"
#include "external/ge/ge_api.h"
#include "easy_graph/builder/graph_dsl.h" #include "easy_graph/builder/graph_dsl.h"
#include "easy_graph/builder/box_builder.h"
#include "easy_graph/layout/graph_layout.h"
#include "easy_graph/layout/engines/graph_easy/graph_easy_option.h"
#include "easy_graph/layout/engines/graph_easy/graph_easy_executor.h"
#include "graph/graph.h"
#include "graph/compute_graph.h" #include "graph/compute_graph.h"
#include "framework/common/types.h" #include "framework/common/types.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/ge_local_context.h"
#include "ge_graph_dsl/graph_dsl.h" #include "ge_graph_dsl/graph_dsl.h"
#include "ge_graph_dsl/op_desc/op_desc_cfg_box.h"
#define protected public
#define private public
#include "ge_opt_info/ge_opt_info.h"
#undef private
#undef protected


namespace ge { namespace ge {
class STEST_opt_info : public testing::Test { class STEST_opt_info : public testing::Test {
@@ -52,7 +40,7 @@ TEST_F(STEST_opt_info, get_opt_info_all) {
DEF_GRAPH(g1) { DEF_GRAPH(g1) {
CHAIN(NODE("data1", DATA)->NODE("add", ADD)); CHAIN(NODE("data1", DATA)->NODE("add", ADD));
CHAIN(NODE("data2", DATA)->NODE("add")); CHAIN(NODE("data2", DATA)->NODE("add"));
});
};


auto graph = ToGeGraph(g1); auto graph = ToGeGraph(g1);


@@ -95,7 +83,7 @@ TEST_F(STEST_opt_info, get_opt_info_success) {
DEF_GRAPH(g1) { DEF_GRAPH(g1) {
CHAIN(NODE("data1", DATA)->NODE("add", ADD)); CHAIN(NODE("data1", DATA)->NODE("add", ADD));
CHAIN(NODE("data2", DATA)->NODE("add")); CHAIN(NODE("data2", DATA)->NODE("add"));
});
};


auto graph = ToGeGraph(g1); auto graph = ToGeGraph(g1);




+ 2
- 2
tests/st/testcase/test_main.cc View File

@@ -15,9 +15,8 @@
*/ */


#include <gtest/gtest.h> #include <gtest/gtest.h>

#include "common/debug/log.h"
#include "external/ge/ge_api.h" #include "external/ge/ge_api.h"
#include "ge_graph_dsl/assert/check_utils.h"
#include "ge_running_env/include/ge_running_env/ge_running_env_faker.h" #include "ge_running_env/include/ge_running_env/ge_running_env_faker.h"


using namespace std; using namespace std;
@@ -31,6 +30,7 @@ int main(int argc, char **argv) {
std::cout << "ge init failed , ret code:" << init_status << endl; std::cout << "ge init failed , ret code:" << init_status << endl;
} }
GeRunningEnvFaker::BackupEnv(); GeRunningEnvFaker::BackupEnv();
CheckUtils::init();
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
int ret = RUN_ALL_TESTS(); int ret = RUN_ALL_TESTS();
return ret; return ret;


+ 1
- 0
tests/ut/common/graph/CMakeLists.txt View File

@@ -90,6 +90,7 @@ set(SRC_FILES
"${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" "${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc"
"${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" "${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc"
"${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" "${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc"
"${GE_CODE_DIR}/metadef/graph/utils/dumper/ge_graph_dumper.cc"
"${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" "${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc"
"${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" "${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc"
"${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc" "${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc"


+ 1
- 0
tests/ut/ge/CMakeLists.txt View File

@@ -102,6 +102,7 @@ set(GRAPH_SRC_FILES
"${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" "${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc"
"${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" "${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc"
"${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" "${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc"
"${GE_CODE_DIR}/metadef/graph/utils/dumper/ge_graph_dumper.cc"
"${GE_CODE_DIR}/metadef/graph/utils/ge_ir_utils.cc" "${GE_CODE_DIR}/metadef/graph/utils/ge_ir_utils.cc"
"${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" "${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc"
"${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" "${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc"


+ 3
- 1
tests/ut/ge/graph/build/task_generator_unittest.cc View File

@@ -116,7 +116,9 @@ TEST_F(UtestTaskGeneratorTest, FindLastBpFromBpNode) {
TaskGenerator task_generator(nullptr, 0); TaskGenerator task_generator(nullptr, 0);
auto net_output = graph->FindNode("Node_Output"); auto net_output = graph->FindNode("Node_Output");
// netoutput has no data input, return default value 0 // netoutput has no data input, return default value 0
EXPECT_EQ(task_generator.FindLastBpFromBpNode(graph, net_output), 0);
uint32_t bp_index = 0;
EXPECT_EQ(task_generator.FindLastBpFromBpNode(graph, net_output, bp_index), 0);
EXPECT_EQ(bp_index, 2);
} }


TEST_F(UtestTaskGeneratorTest, UpdateOpIsVarAttr) { TEST_F(UtestTaskGeneratorTest, UpdateOpIsVarAttr) {


+ 1
- 1
tests/ut/ge/graph/passes/addn_pass_unittest.cc View File

@@ -72,7 +72,7 @@ TEST(UtestGraphPassesAddnPass, null_pass) {
AddNPass *addn_pass = nullptr; AddNPass *addn_pass = nullptr;
NamesToPass names_to_pass; NamesToPass names_to_pass;
names_to_pass.emplace_back("Test", addn_pass); names_to_pass.emplace_back("Test", addn_pass);
EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
EXPECT_NE(pass.Run(names_to_pass), SUCCESS);
} }


TEST(UtestGraphPassesAddnPass, null_graph) { TEST(UtestGraphPassesAddnPass, null_graph) {


+ 115
- 0
tests/ut/ge/graph/passes/base_pass_unittest.cc View File

@@ -67,6 +67,22 @@ class UtestTestPass : public BaseNodePass {
names_to_add_repass_.erase(iter); names_to_add_repass_.erase(iter);
} }
} }

// Simulate a pass that requests an immediate re-pass: if the current node has entries in
// names_to_add_repass_immediate_, look every named node up in the owning graph and register
// it through AddImmediateRePassNode().
iter = names_to_add_repass_immediate_.find(node->GetName());
if (iter != names_to_add_repass_immediate_.end()) {
  auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes();
  for (const auto &node_name : iter->second) {
    for (auto &node_re_pass : all_nodes) {
      if (node_re_pass->GetName() == node_name) {
        AddImmediateRePassNode(node_re_pass);
        break;  // first node with a matching name is used
      }
    }
  }
  if (!dead_loop_) {
    // Erase from the map the iterator was obtained from. Passing it to
    // names_to_add_repass_.erase() is undefined behaviour and would also leave the
    // immediate entry in place, so the request would fire again on every later visit.
    names_to_add_repass_immediate_.erase(iter);
  }
}
// simulate infershape pass // simulate infershape pass
if(node->GetType() == WHILE){ if(node->GetType() == WHILE){
bool need_repass = false; bool need_repass = false;
@@ -94,12 +110,17 @@ class UtestTestPass : public BaseNodePass {
void AddDelNodeName(const std::string &iter_node, const std::string &del_node) { void AddDelNodeName(const std::string &iter_node, const std::string &del_node) {
names_to_add_del_[iter_node].insert(del_node); names_to_add_del_[iter_node].insert(del_node);
} }
// Records that re_pass_node should be added as an immediate re-pass node when the pass
// runs on the node named iter_node. Repeated registrations of the same pair collapse
// because the targets are kept in a set.
void AddRePassImmediateNodeName(const std::string &iter_node, const std::string &re_pass_node) {
  auto &immediate_targets = names_to_add_repass_immediate_[iter_node];
  immediate_targets.insert(re_pass_node);
}

unsigned int GetRunTimes() { return run_times_; } unsigned int GetRunTimes() { return run_times_; }


private: private:
std::vector<NodePtr> iter_nodes_; std::vector<NodePtr> iter_nodes_;
std::map<std::string, std::unordered_set<std::string>> names_to_add_del_; std::map<std::string, std::unordered_set<std::string>> names_to_add_del_;
std::map<std::string, std::unordered_set<std::string>> names_to_add_repass_; std::map<std::string, std::unordered_set<std::string>> names_to_add_repass_;
std::map<std::string, std::unordered_set<std::string>> names_to_add_repass_immediate_;
bool dead_loop_; bool dead_loop_;
unsigned int run_times_; unsigned int run_times_;
}; };
@@ -520,4 +541,98 @@ EXPECT_EQ(while_node->GetOpDesc()->GetSubgraphInstanceNames().size(),1);
EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
} }



// While the pass runs on "reshape1" it requests an immediate re-pass of "add1"
// (named "pre_node" by this test). The expected iteration order is described layer by layer;
// "add1" therefore appears a second time in the last layer.
TEST_F(UTESTGraphPassesBasePass, re_pass_pre_node_immediately) {
NamesToPass names_to_pass;
auto test_pass = UtestTestPass();
names_to_pass.push_back(std::make_pair("test", &test_pass));

// repass pre_node immediately
test_pass.AddRePassImmediateNodeName("reshape1", "add1");

auto graph = BuildGraph2();
auto ge_pass = GEPass(graph);
EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
// 9 equals the total number of names in the layers below (3 + 1 + 2 + 3).
EXPECT_EQ(test_pass.GetIterNodes().size(), 9);// todo
std::vector<std::unordered_set<std::string>> layers;
layers.push_back({"data1", "const1", "const2"});
layers.push_back({"shape1"});
layers.push_back({"add1", "addn1"});
layers.push_back({"reshape1", "add1", "sum1"});
CheckIterOrder(&test_pass, layers);
}
/// sum1
/// / \.
/// / \.
/// / \.
/// reshape1 addn1
/// | c |
/// add1 <--- shape1
/// / \ |
/// | | |
/// data1 const1 const2
// While the pass runs on "reshape1" it requests an immediate re-pass of "reshape1" itself.
// The expected order has "reshape1" alone in one layer and then again, together with "sum1",
// in the next.
TEST_F(UTESTGraphPassesBasePass, re_pass_cur_node_immediately) {
NamesToPass names_to_pass;
auto test_pass = UtestTestPass();
names_to_pass.push_back(std::make_pair("test", &test_pass));

// repass cur_node immediately
test_pass.AddRePassImmediateNodeName("reshape1", "reshape1");

auto graph = BuildGraph2();
auto ge_pass = GEPass(graph);
EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
// 9 equals the total number of names in the layers below (3 + 1 + 2 + 1 + 2).
EXPECT_EQ(test_pass.GetIterNodes().size(), 9);// todo
std::vector<std::unordered_set<std::string>> layers;
layers.push_back({"data1", "const1", "const2"});
layers.push_back({"shape1"});
layers.push_back({"add1", "addn1"});
layers.push_back({"reshape1"});
layers.push_back({"reshape1", "sum1"});
CheckIterOrder(&test_pass, layers);
}

// Immediate re-pass of downstream nodes: both "reshape1" and "add1" are
// registered to request an immediate re-pass of "sum1". The expected iteration
// order still contains "sum1" only once (8 entries in total).
TEST_F(UTESTGraphPassesBasePass, re_pass_next_node_immediately) {
  auto recording_pass = UtestTestPass();
  // repass next_node immediately
  recording_pass.AddRePassImmediateNodeName("reshape1", "sum1");
  // repass node after next_node immediately
  recording_pass.AddRePassImmediateNodeName("add1", "sum1");

  NamesToPass passes;
  passes.push_back(std::make_pair("test", &recording_pass));

  auto compute_graph = BuildGraph2();
  auto pass_runner = GEPass(compute_graph);
  EXPECT_EQ(pass_runner.Run(passes), SUCCESS);
  // 8 is the total number of entries across the expected layers below.
  EXPECT_EQ(recording_pass.GetIterNodes().size(), 8);

  std::vector<std::unordered_set<std::string>> expected_layers = {
      {"data1", "const1", "const2"},
      {"shape1"},
      {"add1", "addn1"},
      {"reshape1", "sum1"},
  };
  CheckIterOrder(&recording_pass, expected_layers);
}
/*
TEST_F(UTESTGraphPassesBasePass, suspend_pre_node) {
NamesToPass names_to_pass;
auto test_pass = UtestTestPass();
names_to_pass.push_back(std::make_pair("test", &test_pass));

// repass next_node immediately
test_pass.AddRePassNodeName("reshape1", "sum1");
// repass node after next_node immediately
test_pass.AddRePassNodeName("add1", "sum1");

auto graph = BuildGraph2();
auto ge_pass = GEPass(graph);
EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
EXPECT_EQ(test_pass.GetIterNodes().size(), 8);// todo
std::vector<std::unordered_set<std::string>> layers;
layers.push_back({"data1", "const1", "const2"});
layers.push_back({"shape1"});
layers.push_back({"add1", "addn1"});
layers.push_back({"reshape1", "sum1"});
CheckIterOrder(&test_pass, layers);
}*/
} // namespace ge } // namespace ge

+ 45
- 0
tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc View File

@@ -293,6 +293,9 @@ class AddKernel : public Kernel {
} else if (input[0]->GetTensorDesc().GetDataType() == DT_INT32 || input[0]->GetTensorDesc().GetDataType() == DT_UINT32) { } else if (input[0]->GetTensorDesc().GetDataType() == DT_INT32 || input[0]->GetTensorDesc().GetDataType() == DT_UINT32) {
vector<int32_t> data_vec; vector<int32_t> data_vec;
auto data_num = input[0]->GetTensorDesc().GetShape().GetShapeSize(); auto data_num = input[0]->GetTensorDesc().GetShape().GetShapeSize();
if (input[0]->GetTensorDesc().GetShape().IsScalar()) {
data_num = 1;
}
auto x1_data = reinterpret_cast<const int32_t *>(input[0]->GetData().data()); auto x1_data = reinterpret_cast<const int32_t *>(input[0]->GetData().data());
auto x2_data = reinterpret_cast<const int32_t *>(input[1]->GetData().data()); auto x2_data = reinterpret_cast<const int32_t *>(input[1]->GetData().data());
for (size_t i = 0; i < data_num; i++) { for (size_t i = 0; i < data_num; i++) {
@@ -410,6 +413,48 @@ TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsHave
EXPECT_EQ(unknown_target_value_range, output_value_range); EXPECT_EQ(unknown_target_value_range, output_value_range);
} }


// Scalar-output case for InferValueRangePass: an Add node is fed by a Shape node
// whose output desc carries the value range [1, 100] and by a scalar Constant
// holding 2. After running the pass on the Add node, its output value range is
// expected to be [3, 102].
TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsAreKnownValueRange_ScalarOutput) {
  //    shape --- add
  // constant  /
  auto graph = std::make_shared<ComputeGraph>("test_graph");

  // Scalar (empty-shape) INT32 constant with value 2.
  vector<int32_t> data_vec = {2};
  GeTensorDesc const_td(ge::GeShape(), ge::FORMAT_NCHW, ge::DT_INT32);
  GeTensorPtr const_tensor = std::make_shared<ge::GeTensor>(
      const_td, reinterpret_cast<uint8_t *>(data_vec.data()), sizeof(int32_t));
  auto const_op_desc = std::make_shared<OpDesc>("Constant", "Constant");
  const_op_desc->AddOutputDesc(const_td);
  EXPECT_EQ(OpDescUtils::SetWeights(const_op_desc, const_tensor), GRAPH_SUCCESS);
  auto const_node = graph->AddNode(const_op_desc);

  // Scalar Shape output whose value is only known as the range [1, 100].
  GeTensorDesc shape_td(GeShape(), ge::FORMAT_NCHW, ge::DT_INT32);
  std::vector<std::pair<int64_t, int64_t>> known_value_range = {make_pair(1, 100)};
  shape_td.SetValueRange(known_value_range);
  auto shape_op_desc = std::make_shared<OpDesc>("Shape", "Shape");
  shape_op_desc->AddOutputDesc(shape_td);
  auto shape_node = graph->AddNode(shape_op_desc);

  GeTensorDesc add_td(GeShape(), ge::FORMAT_NCHW, ge::DT_INT32);
  auto add_op_desc = std::make_shared<OpDesc>("Add", "Add");
  add_op_desc->AddInputDesc(shape_td);
  add_op_desc->AddInputDesc(const_td);
  add_op_desc->AddOutputDesc(add_td);
  auto add_node = graph->AddNode(add_op_desc);

  ge::GraphUtils::AddEdge(shape_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(0));
  ge::GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(1));

  InferValueRangePass infer_pass;
  EXPECT_EQ(infer_pass.Run(add_node), SUCCESS);

  auto output_0_desc = add_node->GetOpDesc()->GetOutputDesc(0);
  std::vector<std::pair<int64_t, int64_t>> out_value_range;
  output_0_desc.GetValueRange(out_value_range);
  // Fatal on purpose: out_value_range[0] is read below, so the test must stop
  // here when the pass did not produce exactly one range.
  ASSERT_EQ(out_value_range.size(), 1U);

  std::vector<int64_t> target_value_range = {3, 102};
  std::vector<int64_t> output_value_range = {out_value_range[0].first, out_value_range[0].second};
  EXPECT_EQ(output_value_range, target_value_range);
}

TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsAreKnownValueRange_Int64) { TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsAreKnownValueRange_Int64) {
// shape --- add --- sqrt // shape --- add --- sqrt
// constant / // constant /


+ 28
- 0
tests/ut/ge/hybrid/executor/hybrid_model_async_executor_unittest.cc View File

@@ -103,4 +103,32 @@ TEST_F(UtestHybridModelAsyncExecutor, Test_execute) {
context.callback_manager->callback_queue_.Push(eof_entry); context.callback_manager->callback_queue_.Push(eof_entry);
ASSERT_EQ(executor.Execute(args), SUCCESS); ASSERT_EQ(executor.Execute(args), SUCCESS);
} }

// PrepareInputs with a dynamic input: input 0 is declared with an unknown first
// dimension (-1, range [1, 256]) and the caller supplies a concrete
// {1, 16, 16, 3} shape together with a 3072-byte buffer. The resulting input
// desc must carry the concrete shape and a tensor size of 3104.
TEST_F(UtestHybridModelAsyncExecutor, test_PrepareInputs) {
  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph);
  ge_root_model->SetModelName("test_name");
  HybridModel hybrid_model(ge_root_model);
  HybridModelAsyncExecutor executor(&hybrid_model);

  GeTensorDescPtr tensor_desc = make_shared<GeTensorDesc>(GeShape({-1, 16, 16, 3}));
  tensor_desc->SetShapeRange({{1, 256}, {16, 16}, {16, 16}, {3, 3}});
  executor.input_tensor_desc_.insert({0, tensor_desc});
  executor.device_id_ = 0;
  executor.input_sizes_.insert({0, -1});
  executor.is_input_dynamic_.push_back(true);

  // 1 * 16 * 16 * 3 elements; presumably 4 bytes each (default data type) -- confirm.
  constexpr uint64_t kInputBytes = 3072U;
  unique_ptr<uint8_t[]> data_buf(new (std::nothrow) uint8_t[kInputBytes]);
  // The nothrow allocation may yield nullptr; never hand that to PrepareInputs.
  ASSERT_TRUE(data_buf != nullptr);
  InputData input_data;
  input_data.blobs.push_back(DataBuffer(data_buf.get(), kInputBytes, false));
  input_data.shapes.push_back({1, 16, 16, 3});
  HybridModelExecutor::ExecuteArgs args;

  auto ret = executor.PrepareInputs(input_data, args);
  ASSERT_EQ(ret, SUCCESS);
  ASSERT_EQ(args.input_desc[0]->GetShape().ToString(), GeShape({1, 16, 16, 3}).ToString());
  int64_t tensor_size = 0;
  TensorUtils::GetSize(*(args.input_desc[0]), tensor_size);
  // NOTE(review): 3104 = 3072 + 32; looks like alignment/padding added by the size calculation -- confirm.
  ASSERT_EQ(tensor_size, 3104);
}
} // namespace ge } // namespace ge

+ 12
- 0
tests/ut/ge/hybrid/executor/worker/execution_engine_unittest.cc View File

@@ -27,6 +27,7 @@
#include "hybrid/executor/hybrid_model_executor.h" #include "hybrid/executor/hybrid_model_executor.h"
#include "hybrid/executor/worker/execution_engine.h" #include "hybrid/executor/worker/execution_engine.h"
#include "hybrid/executor/subgraph_executor.h" #include "hybrid/executor/subgraph_executor.h"
#include "hybrid/executor/worker/task_compile_engine.h"
#undef private #undef private
#undef protected #undef protected


@@ -45,7 +46,14 @@ class UtestExecutionEngine : public testing::Test {
}; };
namespace { namespace {
const int kIntBase = 10; const int kIntBase = 10;
// NodeExecutor stub for these tests: CompileTask ignores its arguments and
// always reports SUCCESS.
class CompileNodeExecutor : public NodeExecutor {
 public:
  Status CompileTask(const HybridModel & /*model*/, const NodePtr & /*node*/,
                     std::shared_ptr<NodeTask> & /*task*/) const override {
    return SUCCESS;
  }
};
} }

static ge::OpDescPtr CreateOpDesc(string name = "", string type = "") { static ge::OpDescPtr CreateOpDesc(string name = "", string type = "") {
auto op_desc = std::make_shared<ge::OpDesc>(name, type); auto op_desc = std::make_shared<ge::OpDesc>(name, type);
op_desc->SetStreamId(0); op_desc->SetStreamId(0);
@@ -128,4 +136,8 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) {
executor.InitCallback(node_state.get(), callback); executor.InitCallback(node_state.get(), callback);
ExecutionEngine execution_engine; ExecutionEngine execution_engine;
EXPECT_EQ(execution_engine.ExecuteAsync(*node_state, node_state->GetTaskContext(), execution_context, callback), INTERNAL_ERROR); EXPECT_EQ(execution_engine.ExecuteAsync(*node_state, node_state->GetTaskContext(), execution_context, callback), INTERNAL_ERROR);
CompileNodeExecutor node_executor;
node_item->node_executor = &node_executor;
EXPECT_EQ(TaskCompileEngine::Compile(*node_state, &execution_context), SUCCESS);
} }

+ 1
- 0
tests/ut/ge/hybrid/ge_hybrid_unittest.cc View File

@@ -153,6 +153,7 @@ TEST_F(UtestGeHybrid, task_update_tiling_info) {
ge::AttrUtils::SetStr(op_desc, "compile_info_json", "json"); ge::AttrUtils::SetStr(op_desc, "compile_info_json", "json");
ge::AttrUtils::SetBool(op_desc, "support_dynamicshape", true); ge::AttrUtils::SetBool(op_desc, "support_dynamicshape", true);
ge::AttrUtils::SetInt(op_desc, "op_para_size", 1); ge::AttrUtils::SetInt(op_desc, "op_para_size", 1);
ge::AttrUtils::SetStr(op_desc, TVM_ATTR_NAME_MAGIC, "RT_DEV_BINARY_MAGIC_ELF");
auto node = graph->AddNode(op_desc); auto node = graph->AddNode(op_desc);


std::unique_ptr<NodeItem> node_item; std::unique_ptr<NodeItem> node_item;


+ 1
- 0
tests/ut/ge/hybrid/node_executor/node_executor_unittest.cc View File

@@ -87,6 +87,7 @@ TEST_F(NodeExecutorTest, TestGetOrCreateExecutor) {
TEST_F(NodeExecutorTest, TestInitAndFinalize) { TEST_F(NodeExecutorTest, TestInitAndFinalize) {
auto &manager = NodeExecutorManager::GetInstance(); auto &manager = NodeExecutorManager::GetInstance();
manager.FinalizeExecutors(); manager.FinalizeExecutors();
manager.FinalizeExecutors();
manager.EnsureInitialized(); manager.EnsureInitialized();
manager.EnsureInitialized(); manager.EnsureInitialized();
const NodeExecutor *executor = nullptr; const NodeExecutor *executor = nullptr;


+ 20
- 1
tests/ut/ge/single_op/single_op_model_unittest.cc View File

@@ -311,7 +311,7 @@ TEST_F(UtestSingleOpModel, BuildTaskList) {
ASSERT_EQ(mem_task.LaunchKernel(0), SUCCESS); ASSERT_EQ(mem_task.LaunchKernel(0), SUCCESS);
} }


TEST_F(UtestSingleOpModel, build_aicpu_task) {
TEST_F(UtestSingleOpModel, build_dynamic_task) {
ComputeGraphPtr graph = make_shared<ComputeGraph>("single_op"); ComputeGraphPtr graph = make_shared<ComputeGraph>("single_op");
GeModelPtr ge_model = make_shared<GeModel>(); GeModelPtr ge_model = make_shared<GeModel>();
ge_model->SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph)); ge_model->SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph));
@@ -321,6 +321,15 @@ TEST_F(UtestSingleOpModel, build_aicpu_task) {
domi::TaskDef *task_def = model_task_def->add_task(); domi::TaskDef *task_def = model_task_def->add_task();
task_def->set_type(RT_MODEL_TASK_KERNEL_EX); task_def->set_type(RT_MODEL_TASK_KERNEL_EX);


domi::TaskDef *task_def2 = model_task_def->add_task();
task_def2->set_type(RT_MODEL_TASK_KERNEL);
domi::KernelDef *kernel_def = task_def2->mutable_kernel();
domi::KernelContext *context = kernel_def->mutable_context();
context->set_kernel_type(6); // ccKernelType::AI_CPU

domi::TaskDef *task_def3 = model_task_def->add_task();
task_def3->set_type(RT_MODEL_TASK_ALL_KERNEL);

string model_data_str = "123456789"; string model_data_str = "123456789";
SingleOpModel model("model", model_data_str.c_str(), model_data_str.size()); SingleOpModel model("model", model_data_str.c_str(), model_data_str.size());
std::mutex stream_mu; std::mutex stream_mu;
@@ -329,8 +338,18 @@ TEST_F(UtestSingleOpModel, build_aicpu_task) {
DynamicSingleOp single_op(0, &stream_mu, stream); DynamicSingleOp single_op(0, &stream_mu, stream);
model.model_helper_.model_ = ge_model; model.model_helper_.model_ = ge_model;
auto op_desc = std::make_shared<ge::OpDesc>("add", "Add"); auto op_desc = std::make_shared<ge::OpDesc>("add", "Add");
AttrUtils::SetStr(op_desc, TVM_ATTR_NAME_MAGIC, "RT_DEV_BINARY_MAGIC_ELF");
std::vector<char> kernelBin;
TBEKernelPtr tbe_kernel = std::make_shared<ge::OpKernelBin>("name/Add", std::move(kernelBin));
op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel);
NodePtr node = graph->AddNode(op_desc); NodePtr node = graph->AddNode(op_desc);
model.op_list_[0] = node; model.op_list_[0] = node;
StreamResource *res = new (std::nothrow) StreamResource(1); StreamResource *res = new (std::nothrow) StreamResource(1);

ASSERT_EQ(model.ParseTasks(), SUCCESS);
ASSERT_EQ(model.BuildTaskListForDynamicOp(res, single_op), SUCCESS);
model.tbe_tasks_.clear();
ASSERT_EQ(model.BuildTaskListForDynamicOp(res, single_op), SUCCESS); ASSERT_EQ(model.BuildTaskListForDynamicOp(res, single_op), SUCCESS);
model.aicpu_tasks_[0] = *task_def2;
model.BuildTaskListForDynamicOp(res, single_op);
} }

+ 1
- 0
tests/ut/ge/single_op/single_op_task_unittest.cc View File

@@ -54,6 +54,7 @@ TEST_F(UtestSingleOpTask, test_build_kernel_task) {


auto graph = make_shared<ComputeGraph>("graph"); auto graph = make_shared<ComputeGraph>("graph");
auto op_desc = make_shared<OpDesc>("Add", "Add"); auto op_desc = make_shared<OpDesc>("Add", "Add");
AttrUtils::SetStr(op_desc, TVM_ATTR_NAME_MAGIC, "RT_DEV_BINARY_MAGIC_ELF");
std::vector<char> kernelBin; std::vector<char> kernelBin;
TBEKernelPtr tbe_kernel = std::make_shared<ge::OpKernelBin>("name/Add", std::move(kernelBin)); TBEKernelPtr tbe_kernel = std::make_shared<ge::OpKernelBin>("name/Add", std::move(kernelBin));
op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel); op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel);


+ 7
- 0
third_party/fwkacllib/inc/external/runtime/rt_error_codes.h View File

@@ -38,6 +38,7 @@ static const int32_t ACL_ERROR_RT_STREAM_NO_CB_REG = 107015; // callba
static const int32_t ACL_ERROR_RT_INVALID_MEMORY_TYPE = 107016; // invalid memory type static const int32_t ACL_ERROR_RT_INVALID_MEMORY_TYPE = 107016; // invalid memory type
static const int32_t ACL_ERROR_RT_INVALID_HANDLE = 107017; // invalid handle static const int32_t ACL_ERROR_RT_INVALID_HANDLE = 107017; // invalid handle
static const int32_t ACL_ERROR_RT_INVALID_MALLOC_TYPE = 107018; // invalid malloc type static const int32_t ACL_ERROR_RT_INVALID_MALLOC_TYPE = 107018; // invalid malloc type
static const int32_t ACL_ERROR_RT_WAIT_TIMEOUT = 107019; // wait timeout
static const int32_t ACL_ERROR_RT_FEATURE_NOT_SUPPORT = 207000; // feature not support static const int32_t ACL_ERROR_RT_FEATURE_NOT_SUPPORT = 207000; // feature not support
static const int32_t ACL_ERROR_RT_MEMORY_ALLOCATION = 207001; // memory allocation error static const int32_t ACL_ERROR_RT_MEMORY_ALLOCATION = 207001; // memory allocation error
@@ -50,6 +51,7 @@ static const int32_t ACL_ERROR_RT_NO_EVENT_RESOURCE = 207007; // no eve
static const int32_t ACL_ERROR_RT_NO_STREAM_RESOURCE = 207008; // no stream resource static const int32_t ACL_ERROR_RT_NO_STREAM_RESOURCE = 207008; // no stream resource
static const int32_t ACL_ERROR_RT_NO_NOTIFY_RESOURCE = 207009; // no notify resource static const int32_t ACL_ERROR_RT_NO_NOTIFY_RESOURCE = 207009; // no notify resource
static const int32_t ACL_ERROR_RT_NO_MODEL_RESOURCE = 207010; // no model resource static const int32_t ACL_ERROR_RT_NO_MODEL_RESOURCE = 207010; // no model resource
static const int32_t ACL_ERROR_RT_NO_CDQ_RESOURCE = 207011; // no cdq resource
static const int32_t ACL_ERROR_RT_INTERNAL_ERROR = 507000; // runtime internal error static const int32_t ACL_ERROR_RT_INTERNAL_ERROR = 507000; // runtime internal error
static const int32_t ACL_ERROR_RT_TS_ERROR = 507001; // ts internel error static const int32_t ACL_ERROR_RT_TS_ERROR = 507001; // ts internel error
@@ -85,9 +87,14 @@ static const int32_t ACL_ERROR_RT_DEBUG_UNREGISTER_FAIL = 507030; // debug
static const int32_t ACL_ERROR_RT_LABEL_CONTEXT = 507031; // label not in current context static const int32_t ACL_ERROR_RT_LABEL_CONTEXT = 507031; // label not in current context
static const int32_t ACL_ERROR_RT_PROGRAM_USE_OUT = 507032; // program register num use out static const int32_t ACL_ERROR_RT_PROGRAM_USE_OUT = 507032; // program register num use out
static const int32_t ACL_ERROR_RT_DEV_SETUP_ERROR = 507033; // device setup error static const int32_t ACL_ERROR_RT_DEV_SETUP_ERROR = 507033; // device setup error
static const int32_t ACL_ERROR_RT_VECTOR_CORE_TIMEOUT = 507034; // vector core timeout
static const int32_t ACL_ERROR_RT_VECTOR_CORE_EXCEPTION = 507035; // vector core exception
static const int32_t ACL_ERROR_RT_VECTOR_CORE_TRAP_EXCEPTION = 507036; // vector core trap exception
static const int32_t ACL_ERROR_RT_CDQ_BATCH_ABNORMAL = 507037; // cdq alloc batch abnormal
static const int32_t ACL_ERROR_RT_DRV_INTERNAL_ERROR = 507899; // drv internal error static const int32_t ACL_ERROR_RT_DRV_INTERNAL_ERROR = 507899; // drv internal error
static const int32_t ACL_ERROR_RT_AICPU_INTERNAL_ERROR = 507900; // aicpu internal error static const int32_t ACL_ERROR_RT_AICPU_INTERNAL_ERROR = 507900; // aicpu internal error
static const int32_t ACL_ERROR_RT_SOCKET_CLOSE = 507901; // hdc disconnect
#ifdef __cplusplus #ifdef __cplusplus
} }


+ 4
- 4
third_party/fwkacllib/inc/runtime/base.h View File

@@ -156,7 +156,7 @@ RTS_API rtError_t rtProfilerTrace(uint64_t id, bool notify, uint32_t flags, rtSt


/** /**
* @ingroup profiling_base * @ingroup profiling_base
* @brief ts send keypoint for step info.
* @brief ts send keypoint profiler log.
*/ */
RTS_API rtError_t rtProfilerTraceEx(uint64_t id, uint64_t modelId, uint16_t tagId, rtStream_t stream); RTS_API rtError_t rtProfilerTraceEx(uint64_t id, uint64_t modelId, uint16_t tagId, rtStream_t stream);


@@ -206,7 +206,7 @@ RTS_API rtError_t rtRegDeviceStateCallback(const char *regName, rtDeviceStateCal


/** /**
* @ingroup dvrt_base * @ingroup dvrt_base
* @brief register callback for fail task
* @brief register callback for fail task
* @param [in] uniName unique register name, can't be null * @param [in] uniName unique register name, can't be null
* @param [in] callback fail task callback function * @param [in] callback fail task callback function
* @param [out] NA * @param [out] NA
@@ -345,11 +345,11 @@ RTS_API rtError_t rtLabelCreateEx(rtLabel_t *label, rtStream_t stream);
* @return RT_ERROR_NONE for ok * @return RT_ERROR_NONE for ok
* @return RT_ERROR_INVALID_VALUE for error input * @return RT_ERROR_INVALID_VALUE for error input
*/ */
rtError_t rtLabelCreateExV2(rtLabel_t *label, rtModel_t model, rtStream_t stream);
RTS_API rtError_t rtLabelCreateExV2(rtLabel_t *label, rtModel_t model, rtStream_t stream);


/** /**
* @ingroup dvrt_base * @ingroup dvrt_base
* @brief get current thread last stream id and task id
* @brief get current thread last stream id and task id
* @param [out] stream id and task id * @param [out] stream id and task id
* @param [in] null * @param [in] null
* @return RT_ERROR_NONE for ok * @return RT_ERROR_NONE for ok


+ 43
- 0
third_party/fwkacllib/inc/runtime/config.h View File

@@ -46,6 +46,12 @@ typedef enum tagRtChipType {
CHIP_END, CHIP_END,
} rtChipType_t; } rtChipType_t;


typedef enum tagRtAicpuScheType {
SCHEDULE_SOFTWARE = 0, /* Software Schedule */
SCHEDULE_SOFTWARE_OPT,
SCHEDULE_HARDWARE, /* HWTS Schedule */
} rtAicpuScheType;

typedef enum tagRtVersion { typedef enum tagRtVersion {
VER_BEGIN = 0, VER_BEGIN = 0,
VER_NA = VER_BEGIN, VER_NA = VER_BEGIN,
@@ -65,6 +71,7 @@ typedef enum tagRtPlatformType {
PLATFORM_LHISI_CS, PLATFORM_LHISI_CS,
PLATFORM_DC, PLATFORM_DC,
PLATFORM_CLOUD_V2, PLATFORM_CLOUD_V2,
PLATFORM_LHISI_SD3403,
PLATFORM_END, PLATFORM_END,
} rtPlatformType_t; } rtPlatformType_t;


@@ -126,6 +133,11 @@ typedef struct tagRtPlatformConfig {
uint32_t platformConfig; uint32_t platformConfig;
} rtPlatformConfig_t; } rtPlatformConfig_t;


typedef enum tagRTTaskTimeoutType {
RT_TIMEOUT_TYPE_OP_WAIT = 0,
RT_TIMEOUT_TYPE_OP_EXECUTE,
} rtTaskTimeoutType_t;

/** /**
* @ingroup * @ingroup
* @brief get AI core count * @brief get AI core count
@@ -184,6 +196,37 @@ RTS_API rtError_t rtMemGetL2Info(rtStream_t stream, void **ptr, uint32_t *size);
*/ */
RTS_API rtError_t rtGetRuntimeVersion(uint32_t *runtimeVersion); RTS_API rtError_t rtGetRuntimeVersion(uint32_t *runtimeVersion);



/**
* @ingroup
* @brief get device feature ability by device id, such as task schedule ability.
* @param [in] deviceId
* @param [in] moduleType
* @param [in] featureType
* @param [out] value
* @return RT_ERROR_NONE for ok
* @return RT_ERROR_INVALID_VALUE for error input
*/
RTS_API rtError_t rtGetDeviceCapability(int32_t deviceId, int32_t moduleType, int32_t featureType, int32_t *value);

/**
* @ingroup
* @brief set event wait task timeout time.
* @param [in] timeout
* @return RT_ERROR_NONE for ok
* @return RT_ERROR_INVALID_VALUE for error input
*/
RTS_API rtError_t rtSetOpWaitTimeOut(uint32_t timeout);

/**
* @ingroup
* @brief set op execute task timeout time.
* @param [in] timeout
* @return RT_ERROR_NONE for ok
* @return RT_ERROR_INVALID_VALUE for error input
*/
RTS_API rtError_t rtSetOpExecuteTimeOut(uint32_t timeout);

#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) #if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE)
} }
#endif #endif


+ 5
- 0
third_party/fwkacllib/inc/runtime/dev.h View File

@@ -63,6 +63,11 @@ typedef enum tagRtFeatureType {
FEATURE_TYPE_RSV FEATURE_TYPE_RSV
} rtFeatureType_t; } rtFeatureType_t;


typedef enum tagRtDeviceFeatureType {
FEATURE_TYPE_SCHE,
FEATURE_TYPE_END,
} rtDeviceFeatureType_t;

typedef enum tagMemcpyInfo { typedef enum tagMemcpyInfo {
MEMCPY_INFO_SUPPORT_ZEROCOPY = 0, MEMCPY_INFO_SUPPORT_ZEROCOPY = 0,
MEMCPY_INFO_RSV MEMCPY_INFO_RSV


+ 35
- 2
third_party/fwkacllib/inc/runtime/event.h View File

@@ -23,12 +23,23 @@
extern "C" { extern "C" {
#endif #endif


/* Wait status of an event; the type of the status argument of rtEventQueryWaitStatus. */
typedef enum rtEventWaitStatus {
EVENT_STATUS_COMPLETE = 0, /* complete */
EVENT_STATUS_NOT_READY = 1, /* not complete */
EVENT_STATUS_MAX = 2, /* NOTE(review): looks like an upper-bound sentinel rather than a real status -- confirm */
} rtEventWaitStatus_t;

/** /**
* @ingroup event_flags * @ingroup event_flags
* @brief event op bit flags * @brief event op bit flags
*/ */
#define RT_EVENT_DEFAULT (0x00)
#define RT_EVENT_WITH_FLAG (0x01)
#define RT_EVENT_DEFAULT (0x0E)
#define RT_EVENT_WITH_FLAG (0x0B)

#define RT_EVENT_DDSYNC_NS 0x01U
#define RT_EVENT_STREAM_MARK 0x02U
#define RT_EVENT_DDSYNC 0x04U
#define RT_EVENT_TIME_LINE 0x08U


/** /**
* @ingroup dvrt_event * @ingroup dvrt_event
@@ -104,6 +115,16 @@ RTS_API rtError_t rtEventSynchronize(rtEvent_t event);
*/ */
RTS_API rtError_t rtEventQuery(rtEvent_t event); RTS_API rtError_t rtEventQuery(rtEvent_t event);


/**
* @ingroup dvrt_event
* @brief Queries an event's wait status
* @param [in] event event to query
* @param [in,out] status wait status of the event (rtEventWaitStatus_t):
*                 EVENT_STATUS_COMPLETE for complete, EVENT_STATUS_NOT_READY for not complete
* @return RT_ERROR_NONE for ok
*/
RTS_API rtError_t rtEventQueryWaitStatus(rtEvent_t event, rtEventWaitStatus_t *status);

/** /**
* @ingroup dvrt_event * @ingroup dvrt_event
* @brief computes the elapsed time between events. * @brief computes the elapsed time between events.
@@ -176,6 +197,18 @@ RTS_API rtError_t rtNotifyRecord(rtNotify_t notify, rtStream_t stream);
*/ */
RTS_API rtError_t rtNotifyWait(rtNotify_t notify, rtStream_t stream); RTS_API rtError_t rtNotifyWait(rtNotify_t notify, rtStream_t stream);


/**
* @ingroup dvrt_event
* @brief Wait for a notify, with a timeout
* @param [in] notify_ notify to wait for
* @param [in] stream_ input stream
* @param [in] timeOut timeout for the wait
* @return RT_ERROR_NONE for ok
* @return RT_ERROR_INVALID_VALUE for error input
* @return RT_ERROR_STREAM_CONTEXT for stream is not in current ctx
*/
RTS_API rtError_t rtNotifyWaitWithTimeOut(rtNotify_t notify_, rtStream_t stream_, uint32_t timeOut);

/** /**
* @ingroup dvrt_event * @ingroup dvrt_event
* @brief Name a notify * @brief Name a notify


+ 66
- 17
third_party/fwkacllib/inc/runtime/kernel.h View File

@@ -111,6 +111,16 @@ typedef struct rtKernelInfo {
uint32_t module_size; uint32_t module_size;
} *rtKernelInfo_t; } *rtKernelInfo_t;


/**
* @ingroup rt_kernel
* @brief names used for a kernel launch: so name, kernel type name and operator name
*        (passed as launchNames to rtAicpuKernelLaunch / rtAicpuKernelLaunchWithFlag)
*/
typedef struct rtKernelLaunchNames {
const char *soName; // defined for so name
const char *kernelName; // defined for kernel type name
const char *opName; // defined for operator name
} rtKernelLaunchNames_t;

/** /**
* @ingroup rt_KernelConfigDump * @ingroup rt_KernelConfigDump
* @brief device dump type * @brief device dump type
@@ -173,13 +183,7 @@ typedef void (*rtCallback_t)(void *fnData);
* @ingroup rt_kernel * @ingroup rt_kernel
* @brief magic number of elf binary for aicube * @brief magic number of elf binary for aicube
*/ */
#define RT_DEV_BINARY_MAGIC_ELF_AICUBE 0x41415247

/**
* @ingroup rt_kernel
* @brief magic number of elf binary for aivector
*/
#define RT_DEV_BINARY_MAGIC_ELF_AIVECTOR 0x41415248
#define RT_DEV_BINARY_MAGIC_ELF_AICUBE 0x41494343


/** /**
* @ingroup rt_kernel_flags * @ingroup rt_kernel_flags
@@ -192,14 +196,14 @@ typedef void (*rtCallback_t)(void *fnData);
#define RT_KERNEL_CUSTOM_AICPU (0x08) #define RT_KERNEL_CUSTOM_AICPU (0x08)


// STARS topic scheduler sqe : topic_type // STARS topic scheduler sqe : topic_type
#define RT_KERNEL_DEVICE_FIRST (0X10)
#define RT_KERNEL_HOST_ONLY (0X20)
#define RT_KERNEL_HOST_FIRST (0X30)
#define RT_KERNEL_DEVICE_FIRST (0x10)
#define RT_KERNEL_HOST_ONLY (0x20)
#define RT_KERNEL_HOST_FIRST (0x40)


/** /**
* @ingroup rt_kernel * @ingroup rt_kernel
* @brief kernel mode * @brief kernel mode
*/
**/
#define RT_DEFAULT_KERNEL_MODE (0x00) #define RT_DEFAULT_KERNEL_MODE (0x00)
#define RT_NORMAL_KERNEL_MODE (0x01) #define RT_NORMAL_KERNEL_MODE (0x01)
#define RT_ALL_KERNEL_MODE (0x02) #define RT_ALL_KERNEL_MODE (0x02)
@@ -222,7 +226,7 @@ RTS_API rtError_t rtDevBinaryRegister(const rtDevBinary_t *bin, void **handle);


/** /**
* @ingroup rt_kernel * @ingroup rt_kernel
* @brief register device binary
* @brief register device binary with all kernel
* @param [in] bin device binary description * @param [in] bin device binary description
* @param [out] handle device binary handle * @param [out] handle device binary handle
* @return RT_ERROR_NONE for ok * @return RT_ERROR_NONE for ok
@@ -341,7 +345,7 @@ RTS_API rtError_t rtKernelLaunch(const void *stubFunc, uint32_t blockDim, void *
* @ingroup rt_kernel * @ingroup rt_kernel
* @brief launch kernel with handle to device * @brief launch kernel with handle to device
* @param [in] handle program * @param [in] handle program
* @param [in] devFunc device function description
* @param [in] devFunc device function description.
* @param [in] blockDim block dimentions * @param [in] blockDim block dimentions
* @param [in] args argments address for kernel function * @param [in] args argments address for kernel function
* @param [in] argsSize argements size * @param [in] argsSize argements size
@@ -352,7 +356,7 @@ RTS_API rtError_t rtKernelLaunch(const void *stubFunc, uint32_t blockDim, void *
* @return RT_ERROR_INVALID_VALUE for error input * @return RT_ERROR_INVALID_VALUE for error input
*/ */
RTS_API rtError_t rtKernelLaunchWithHandle(void *handle, const void *devFunc, uint32_t blockDim, void *args, uint32_t argsSize, RTS_API rtError_t rtKernelLaunchWithHandle(void *handle, const void *devFunc, uint32_t blockDim, void *args, uint32_t argsSize,
rtSmDesc_t *smDesc, rtStream_t stream, const void *kernelInfo);
rtSmDesc_t *smDesc, rtStream_t stream_, const void *kernelInfo);


/** /**
* @ingroup rt_kernel * @ingroup rt_kernel
@@ -371,7 +375,7 @@ RTS_API rtError_t rtKernelLaunchWithFlag(const void *stubFunc, uint32_t blockDim
rtSmDesc_t *smDesc, rtStream_t stream, uint32_t flags); rtSmDesc_t *smDesc, rtStream_t stream, uint32_t flags);


/** /**
* @ingroup rt_kernel
* @ingroup rt_kernel(abandoned)
* @brief launch kernel to device * @brief launch kernel to device
* @param [in] args argments address for kernel function * @param [in] args argments address for kernel function
* @param [in] argsSize argements size * @param [in] argsSize argements size
@@ -383,7 +387,21 @@ RTS_API rtError_t rtKernelLaunchWithFlag(const void *stubFunc, uint32_t blockDim
RTS_API rtError_t rtKernelLaunchEx(void *args, uint32_t argsSize, uint32_t flags, rtStream_t stream); RTS_API rtError_t rtKernelLaunchEx(void *args, uint32_t argsSize, uint32_t flags, rtStream_t stream);


/** /**
* @ingroup rt_kernel
* @ingroup rt_kernel(in use)
* @brief launch kernel to device
* @param [in] opName op kernel name
* @param [in] args arguments address for kernel function
* @param [in] argsSize arguments size
* @param [in] flags launch flags
* @param [in] rtStream associated stream
* @return RT_ERROR_NONE for ok
* @return RT_ERROR_INVALID_VALUE for error input
*/
RTS_API rtError_t rtKernelLaunchFwk(const char *opName, void *args, uint32_t argsSize, uint32_t flags,
rtStream_t rtStream);

/**
* @ingroup rt_kernel(abandoned)
* @brief launch cpu kernel to device * @brief launch cpu kernel to device
* @param [in] soName so name * @param [in] soName so name
* @param [in] kernelName kernel name * @param [in] kernelName kernel name
@@ -399,7 +417,22 @@ RTS_API rtError_t rtCpuKernelLaunch(const void *soName, const void *kernelName,
uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream); uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream);


/** /**
* @ingroup rt_kernel
* @ingroup rt_kernel(in use)
* @brief launch cpu kernel to device
* @param [in] launchNames names for kernel launch
* @param [in] blockDim block dimensions
* @param [in] args arguments address for kernel function
* @param [in] argsSize arguments size
* @param [in] smDesc shared memory description
* @param [in] stream associated stream
* @return RT_ERROR_NONE for ok
* @return RT_ERROR_INVALID_VALUE for error input
*/
RTS_API rtError_t rtAicpuKernelLaunch(const rtKernelLaunchNames_t *launchNames,
uint32_t blockDim, const void *args, uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream);

/**
* @ingroup rt_kernel(abandoned)
* @brief launch cpu kernel to device with dump identifier * @brief launch cpu kernel to device with dump identifier
* @param [in] soName so name * @param [in] soName so name
* @param [in] kernelName kernel name * @param [in] kernelName kernel name
@@ -416,6 +449,22 @@ RTS_API rtError_t rtCpuKernelLaunchWithFlag(const void *soName, const void *kern
const void *args, uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream, const void *args, uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream,
uint32_t flags); uint32_t flags);


/**
* @ingroup rt_kernel(in use)
* @brief launch cpu kernel to device with dump identifier
* @param [in] launchNames names for kernel launch
* @param [in] blockDim block dimensions
* @param [in] args arguments address for kernel function
* @param [in] argsSize arguments size
* @param [in] smDesc shared memory description
* @param [in] stream associated stream
* @param [in] flags dump flag or other function flags
* @return RT_ERROR_NONE for ok
* @return RT_ERROR_INVALID_VALUE for error input
*/
RTS_API rtError_t rtAicpuKernelLaunchWithFlag(const rtKernelLaunchNames_t *launchNames, uint32_t blockDim,
const void *args, uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream, uint32_t flags);

/** /**
* @ingroup rt_kernel * @ingroup rt_kernel
* @brief L1 fusion dump addr transfered to device * @brief L1 fusion dump addr transfered to device


+ 11
- 0
third_party/fwkacllib/inc/runtime/mem.h View File

@@ -116,6 +116,9 @@ typedef enum tagRtMemInfoType {


typedef enum tagRtRecudeKind { typedef enum tagRtRecudeKind {
RT_MEMCPY_SDMA_AUTOMATIC_ADD = 10, // D2D, SDMA inline reduce, include 1P, and P2P RT_MEMCPY_SDMA_AUTOMATIC_ADD = 10, // D2D, SDMA inline reduce, include 1P, and P2P
RT_MEMCPY_SDMA_AUTOMATIC_MAX = 11,
RT_MEMCPY_SDMA_AUTOMATIC_MIN = 12,
RT_MEMCPY_SDMA_AUTOMATIC_EQUAL = 13,
RT_RECUDE_KIND_END RT_RECUDE_KIND_END
} rtRecudeKind_t; } rtRecudeKind_t;


@@ -123,6 +126,14 @@ typedef enum tagRtDataType {
RT_DATA_TYPE_FP32 = 0, // fp32 RT_DATA_TYPE_FP32 = 0, // fp32
RT_DATA_TYPE_FP16 = 1, // fp16 RT_DATA_TYPE_FP16 = 1, // fp16
RT_DATA_TYPE_INT16 = 2, // int16 RT_DATA_TYPE_INT16 = 2, // int16
RT_DATA_TYPE_INT4 = 3, // int4
RT_DATA_TYPE_INT8 = 4, // int8
RT_DATA_TYPE_INT32 = 5, // int32
RT_DATA_TYPE_BFP16 = 6, // bfp16
RT_DATA_TYPE_BFP32 = 7, // bfp32
RT_DATA_TYPE_UINT8 = 8, // uint8
RT_DATA_TYPE_UINT16= 9, // uint16
RT_DATA_TYPE_UINT32= 10,// uint32
RT_DATA_TYPE_END RT_DATA_TYPE_END
} rtDataType_t; } rtDataType_t;




+ 11
- 2
third_party/fwkacllib/inc/runtime/rt_model.h View File

@@ -135,12 +135,13 @@ typedef struct tagAllKernelTaskInfo {
uint16_t argsCount; uint16_t argsCount;
uint16_t argsSize; uint16_t argsSize;
uint16_t reserved; uint16_t reserved;
const void *dev_func;
void *devfunc;
void *handle; void *handle;
uint8_t *smDesc; uint8_t *smDesc;
uint8_t *args; uint8_t *args;
uint16_t *argsOffset; uint16_t *argsOffset;
} rtAllKernelTaskInfo_t; } rtAllKernelTaskInfo_t;

typedef struct tagKernelTaskInfoEx { typedef struct tagKernelTaskInfoEx {
uint32_t flags; uint32_t flags;
uint32_t argsSize; uint32_t argsSize;
@@ -198,6 +199,13 @@ typedef struct tagProfilerTraceTaskInfo {
uint32_t reserved[6]; uint32_t reserved[6];
} rtProfilerTrace_t; } rtProfilerTrace_t;


/* Task payload for the extended profiler-trace task (rtProfilerTraceEx_t). */
typedef struct tagProfilerTraceExTaskInfo {
uint64_t profilerTraceId; /* NOTE(review): presumably identifies the trace point - confirm with runtime docs */
uint64_t modelId; /* NOTE(review): presumably the model the trace belongs to - confirm */
uint16_t tagId; /* NOTE(review): tag semantics are not visible in this header - confirm */
uint8_t reserved[22]; /* reserved bytes; 8 + 8 + 2 + 22 = 40 bytes of declared fields */
} rtProfilerTraceEx_t;

typedef struct tagrtMemcpyAsyncTaskInfo { typedef struct tagrtMemcpyAsyncTaskInfo {
void *dst; void *dst;
uint64_t destMax; uint64_t destMax;
@@ -265,7 +273,7 @@ typedef struct tagTaskInfo {
union { union {
rtKernelTaskInfoEx_t kernelTaskEx; rtKernelTaskInfoEx_t kernelTaskEx;
rtKernelTaskInfo_t kernelTask; rtKernelTaskInfo_t kernelTask;
rtAllKernelTaskInfo_t allkernelTask;
rtAllKernelTaskInfo_t allKernelTask;
rtEventTaskInfo_t eventTask; rtEventTaskInfo_t eventTask;
rtStreamSwitchTaskInfo_t streamSwitchTask; rtStreamSwitchTaskInfo_t streamSwitchTask;
rtStreamActiveTaskInfo_t streamActiveTask; rtStreamActiveTaskInfo_t streamActiveTask;
@@ -273,6 +281,7 @@ typedef struct tagTaskInfo {
rtLabelSwitchTaskInfo_t labelSwitchTask; rtLabelSwitchTaskInfo_t labelSwitchTask;
rtLabelGotoTaskInfo_t labelGotoTask; rtLabelGotoTaskInfo_t labelGotoTask;
rtProfilerTrace_t profilertraceTask; rtProfilerTrace_t profilertraceTask;
rtProfilerTraceEx_t profilertraceExTask;
rtMemcpyAsyncTaskInfo_t memcpyAsyncTask; rtMemcpyAsyncTaskInfo_t memcpyAsyncTask;
rtNotifyTaskInfo_t notifyTask; rtNotifyTaskInfo_t notifyTask;
rtReduceAsyncTaskInfo_t reduceAsyncTask; rtReduceAsyncTaskInfo_t reduceAsyncTask;


+ 30
- 1
third_party/fwkacllib/inc/toolchain/prof_callback.h View File

@@ -108,7 +108,19 @@ enum MsprofCtrlCallbackType {
MSPROF_CTRL_INIT_ACL_ENV = 0, // start profiling with acl env MSPROF_CTRL_INIT_ACL_ENV = 0, // start profiling with acl env
MSPROF_CTRL_INIT_ACL_JSON, // start profiling with acl.json MSPROF_CTRL_INIT_ACL_JSON, // start profiling with acl.json
MSPROF_CTRL_INIT_GE_OPTIONS, // start profiling with ge env and options MSPROF_CTRL_INIT_GE_OPTIONS, // start profiling with ge env and options
MSPROF_CTRL_FINALIZE // stop profiling
MSPROF_CTRL_FINALIZE, // stop profiling
MSPROF_CTRL_REPORT_FUN_P, // for report callback
MSPROF_CTRL_PROF_SWITCH_ON, // for prof switch on
MSPROF_CTRL_PROF_SWITCH_OFF // for prof switch off
};

#define MSPROF_MAX_DEV_NUM (64)

struct MsprofCommandHandle {
uint64_t profSwitch;
uint32_t devNums; // length of device id list
uint32_t devIdList[MSPROF_MAX_DEV_NUM];
uint32_t modelId;
}; };


/** /**
@@ -129,6 +141,23 @@ typedef int32_t (*MsprofCtrlCallback)(uint32_t type, void *data, uint32_t len);
*/ */
typedef void (*MsprofSetDeviceCallback)(uint32_t devId, bool isOpenDevice); typedef void (*MsprofSetDeviceCallback)(uint32_t devId, bool isOpenDevice);


/*
* @name MsprofInit
* @brief Initialize the profiling module
* @param [in] dataType: profiling type: ACL Env/ACL Json/GE Option
* @param [in] data: profiling switch data
* @param [in] dataLen: length of data
* @return 0:SUCCESS, >0:FAILED
*/
int32_t MsprofInit(uint32_t dataType, void *data, uint32_t dataLen);

/*
* @name MsprofFinalize
* @brief Finish profiling
* @param none
* @return 0:SUCCESS, >0:FAILED
*/
int32_t MsprofFinalize();
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif


+ 32
- 30
third_party/fwkacllib/inc/toolchain/slog.h View File

@@ -17,6 +17,8 @@
#ifndef D_SYSLOG_H_ #ifndef D_SYSLOG_H_
#define D_SYSLOG_H_ #define D_SYSLOG_H_


static const int TMP_LOG = 0;

#ifdef __cplusplus #ifdef __cplusplus
#ifndef LOG_CPP #ifndef LOG_CPP
extern "C" { extern "C" {
@@ -120,15 +122,15 @@ typedef struct tagKV {
} KeyValue; } KeyValue;


typedef enum { typedef enum {
APPLICATION = 0,
SYSTEM
APPLICATION = 0,
SYSTEM
} ProcessType; } ProcessType;


typedef struct { typedef struct {
ProcessType type;
unsigned int pid;
unsigned int deviceId;
char reserved[RESERVERD_LENGTH];
ProcessType type;
unsigned int pid;
unsigned int deviceId;
char reserved[RESERVERD_LENGTH];
} LogAttr; } LogAttr;


/** /**
@@ -141,7 +143,7 @@ enum {
IDEDD, /**< IDE daemon device */ IDEDD, /**< IDE daemon device */
IDEDH, /**< IDE daemon host */ IDEDH, /**< IDE daemon host */
HCCL, /**< HCCL */ HCCL, /**< HCCL */
FMK, /**< Framework */
FMK, /**< Adapter */
HIAIENGINE, /**< Matrix */ HIAIENGINE, /**< Matrix */
DVPP, /**< DVPP */ DVPP, /**< DVPP */
RUNTIME, /**< Runtime */ RUNTIME, /**< Runtime */
@@ -162,11 +164,11 @@ enum {
MDCDEFAULT, /**< MDC undefine */ MDCDEFAULT, /**< MDC undefine */
MDCSC, /**< MDC spatial cognition */ MDCSC, /**< MDC spatial cognition */
MDCPNC, MDCPNC,
MLL,
MLL, /**< abandon */
DEVMM, /**< Dlog memory managent */ DEVMM, /**< Dlog memory managent */
KERNEL, /**< Kernel */ KERNEL, /**< Kernel */
LIBMEDIA, /**< Libmedia */ LIBMEDIA, /**< Libmedia */
CCECPU, /**< ai cpu */
CCECPU, /**< aicpu shedule */
ASCENDDK, /**< AscendDK */ ASCENDDK, /**< AscendDK */
ROS, /**< ROS */ ROS, /**< ROS */
HCCP, HCCP,
@@ -179,7 +181,7 @@ enum {
TSDUMP, /**< TSDUMP module */ TSDUMP, /**< TSDUMP module */
AICPU, /**< AICPU module */ AICPU, /**< AICPU module */
LP, /**< LP module */ LP, /**< LP module */
TDT,
TDT, /**< tsdaemon or aicpu shedule */
FE, FE,
MD, MD,
MB, MB,
@@ -261,7 +263,7 @@ DLL_EXPORT int DlogSetAttr(LogAttr logAttr);
#define dlog_error(moduleId, fmt, ...) \ #define dlog_error(moduleId, fmt, ...) \
do { \ do { \
DlogErrorInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ DlogErrorInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \
} while (0)
} while (TMP_LOG != 0)


/** /**
* @ingroup slog * @ingroup slog
@@ -276,7 +278,7 @@ DLL_EXPORT int DlogSetAttr(LogAttr logAttr);
if(CheckLogLevel(moduleId, DLOG_WARN) == 1) { \ if(CheckLogLevel(moduleId, DLOG_WARN) == 1) { \
DlogWarnInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ DlogWarnInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \
} \ } \
} while (0)
} while (TMP_LOG != 0)


/** /**
* @ingroup slog * @ingroup slog
@@ -291,7 +293,7 @@ DLL_EXPORT int DlogSetAttr(LogAttr logAttr);
if(CheckLogLevel(moduleId, DLOG_INFO) == 1) { \ if(CheckLogLevel(moduleId, DLOG_INFO) == 1) { \
DlogInfoInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ DlogInfoInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \
} \ } \
} while (0)
} while (TMP_LOG != 0)


/** /**
* @ingroup slog * @ingroup slog
@@ -306,7 +308,7 @@ DLL_EXPORT int DlogSetAttr(LogAttr logAttr);
if(CheckLogLevel(moduleId, DLOG_DEBUG) == 1) { \ if(CheckLogLevel(moduleId, DLOG_DEBUG) == 1) { \
DlogDebugInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ DlogDebugInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \
} \ } \
} while (0)
} while (TMP_LOG != 0)


/** /**
* @ingroup slog * @ingroup slog
@@ -318,7 +320,7 @@ DLL_EXPORT int DlogSetAttr(LogAttr logAttr);
#define dlog_event(moduleId, fmt, ...) \ #define dlog_event(moduleId, fmt, ...) \
do { \ do { \
DlogEventInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ DlogEventInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \
} while (0)
} while (TMP_LOG != 0)


/** /**
* @ingroup slog * @ingroup slog
@@ -334,7 +336,7 @@ DLL_EXPORT int DlogSetAttr(LogAttr logAttr);
if(CheckLogLevel(moduleId, level) == 1) { \ if(CheckLogLevel(moduleId, level) == 1) { \
DlogInner(moduleId, level, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ DlogInner(moduleId, level, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \
} \ } \
} while (0)
} while (TMP_LOG != 0)


/** /**
* @ingroup slog * @ingroup slog
@@ -351,7 +353,7 @@ DLL_EXPORT int DlogSetAttr(LogAttr logAttr);
if(CheckLogLevel(moduleId, level) == 1) { \ if(CheckLogLevel(moduleId, level) == 1) { \
DlogInner(moduleId, level, "[%s:%d][%s]" fmt, __FILE__, __LINE__, submodule, ##__VA_ARGS__); \ DlogInner(moduleId, level, "[%s:%d][%s]" fmt, __FILE__, __LINE__, submodule, ##__VA_ARGS__); \
} \ } \
} while (0)
} while (TMP_LOG != 0)


/** /**
* @ingroup slog * @ingroup slog
@@ -369,7 +371,7 @@ DLL_EXPORT int DlogSetAttr(LogAttr logAttr);
if(CheckLogLevel(moduleId, level) == 1) { \ if(CheckLogLevel(moduleId, level) == 1) { \
DlogWithKVInner(moduleId, level, pstKVArray, kvNum, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ DlogWithKVInner(moduleId, level, pstKVArray, kvNum, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \
} \ } \
} while (0)
} while (TMP_LOG != 0)


/** /**
* @ingroup slog * @ingroup slog
@@ -381,13 +383,13 @@ DLL_EXPORT void DlogFlush(void);
* @ingroup slog * @ingroup slog
* @brief Internal log interface, other modules are not allowed to call this interface * @brief Internal log interface, other modules are not allowed to call this interface
*/ */
void DlogErrorInner(int moduleId, const char *fmt, ...) __attribute__((format(printf, 2, 3)));
void DlogWarnInner(int moduleId, const char *fmt, ...) __attribute__((format(printf, 2, 3)));
void DlogInfoInner(int moduleId, const char *fmt, ...) __attribute__((format(printf, 2, 3)));
void DlogDebugInner(int moduleId, const char *fmt, ...) __attribute__((format(printf, 2, 3)));
void DlogEventInner(int moduleId, const char *fmt, ...) __attribute__((format(printf, 2, 3)));
void DlogInner(int moduleId, int level, const char *fmt, ...) __attribute__((format(printf, 3, 4)));
void DlogWithKVInner(int moduleId, int level, KeyValue *pstKVArray, int kvNum, const char *fmt, ...) __attribute__((format(printf, 5, 6)));
void DlogErrorInner(int moduleId, const char *fmt, ...);
void DlogWarnInner(int moduleId, const char *fmt, ...);
void DlogInfoInner(int moduleId, const char *fmt, ...);
void DlogDebugInner(int moduleId, const char *fmt, ...);
void DlogEventInner(int moduleId, const char *fmt, ...);
void DlogInner(int moduleId, int level, const char *fmt, ...);
void DlogWithKVInner(int moduleId, int level, KeyValue *pstKVArray, int kvNum, const char *fmt, ...);


#ifdef __cplusplus #ifdef __cplusplus
#ifndef LOG_CPP #ifndef LOG_CPP
@@ -453,7 +455,7 @@ DLL_EXPORT int DlogSetAttrForC(LogAttr logAttr);
if(CheckLogLevelForC(moduleId, level) == 1) { \ if(CheckLogLevelForC(moduleId, level) == 1) { \
DlogInnerForC(moduleId, level, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ DlogInnerForC(moduleId, level, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \
} \ } \
} while (0)
} while (TMP_LOG != 0)


/** /**
* @ingroup slog * @ingroup slog
@@ -470,7 +472,7 @@ DLL_EXPORT int DlogSetAttrForC(LogAttr logAttr);
if(CheckLogLevelForC(moduleId, level) == 1) { \ if(CheckLogLevelForC(moduleId, level) == 1) { \
DlogInnerForC(moduleId, level, "[%s:%d][%s]" fmt, __FILE__, __LINE__, submodule, ##__VA_ARGS__); \ DlogInnerForC(moduleId, level, "[%s:%d][%s]" fmt, __FILE__, __LINE__, submodule, ##__VA_ARGS__); \
} \ } \
} while (0)
} while (TMP_LOG != 0)


/** /**
* @ingroup slog * @ingroup slog
@@ -488,7 +490,7 @@ DLL_EXPORT int DlogSetAttrForC(LogAttr logAttr);
if(CheckLogLevelForC(moduleId, level) == 1) { \ if(CheckLogLevelForC(moduleId, level) == 1) { \
DlogWithKVInnerForC(moduleId, level, pstKVArray, kvNum, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ DlogWithKVInnerForC(moduleId, level, pstKVArray, kvNum, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \
} \ } \
} while (0)
} while (TMP_LOG != 0)


/** /**
* @ingroup slog * @ingroup slog
@@ -500,8 +502,8 @@ DLL_EXPORT void DlogFlushForC(void);
* @ingroup slog * @ingroup slog
* @brief Internal log interface, other modules are not allowed to call this interface * @brief Internal log interface, other modules are not allowed to call this interface
*/ */
void DlogInnerForC(int moduleId, int level, const char *fmt, ...) __attribute__((format(printf, 3, 4)));
void DlogWithKVInnerForC(int moduleId, int level, KeyValue *pstKVArray, int kvNum, const char *fmt, ...) __attribute__((format(printf, 5, 6)));
void DlogInnerForC(int moduleId, int level, const char *fmt, ...);
void DlogWithKVInnerForC(int moduleId, int level, KeyValue *pstKVArray, int kvNum, const char *fmt, ...);


#ifdef __cplusplus #ifdef __cplusplus
} }


+ 88
- 72
third_party/fwkacllib/inc/toolchain/tuning_tool/tune_api.h View File

@@ -1,72 +1,88 @@
/**
* @file tune_api.h
*
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2020. All rights reserved.\n
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.\n
* 描述:mstune调优接口头文件
*/
/** @defgroup mstune mstune调优接口 */
#ifndef TUNE_API_H
#define TUNE_API_H
#include <vector>
#include <map>
#include <string>
#include "graph/graph.h"
#include "ge/ge_api.h"
/**
* @ingroup mstune
*
* mstune status
*/
enum MsTuneStatus {
MSTUNE_SUCCESS, /** tune success */
MSTUNE_FAILED, /** tune failed */
};
// Option key: for train options sets
const std::string MSTUNE_SELF_KEY = "mstune";
const std::string MSTUNE_GEINIT_KEY = "initialize";
const std::string MSTUNE_GESESS_KEY = "session";
/**
* @ingroup mstune
* @par 描述: 命令行调优
*
* @attention 无
* @param option [IN] 调优参数
* @param msg [OUT] 调优异常下返回信息
* @retval #MSTUNE_SUCCESS 执行成功
* @retval #MSTUNE_FAILED 执行失败
* @par 依赖:
* @li tune_api.cpp:该接口所属的开发包。
* @li tune_api.h:该接口声明所在的头文件。
* @see 无
* @since
*/
MsTuneStatus MsTuning(const std::map<std::string, std::string> &option, std::string &msg);
/**
* @ingroup mstune
* @par 描述: 梯度调优
*
* @attention 无
* @param tuningGraph [IN] 调优图
* @param dependGraph [IN] 调优依赖图
* @param session [IN] ge连接会话
* @param option [IN] 参数集. 包含调优参数及ge参数
* @retval #MSTUNE_SUCCESS 执行成功
* @retval #MSTUNE_FAILED 执行失败
* @par 依赖:
* @li tune_api.cpp:该接口所属的开发包。
* @li tune_api.h:该接口声明所在的头文件。
* @see 无
* @since
*/
extern "C" MsTuneStatus MsTrainTuning(ge::Graph &tuningGraph, std::vector<ge::Graph> &dependGraph,
ge::Session *session, const std::map<std::string, std::map<std::string, std::string>> &option);
#endif
/**
* @file tune_api.h
*
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.\n
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.\n
* Description: header file for the AOE tuning interfaces
*/
/** @defgroup aoe AOE tuning interfaces */
#ifndef TUNE_API_H
#define TUNE_API_H
#include <map>
#include <string>
#include "ge/ge_api.h"
#include "aoe_types.h"

/**
* @ingroup aoe
* @par Description: command-line (offline) tuning
*
* @attention None
* @param option [IN] tuning options
* @param msg [OUT] message returned when tuning fails
* @retval #AOE_SUCCESS execution succeeded
* @retval #AOE_FAILURE execution failed
* @par Dependencies:
* @li tune_api.cpp: the development package this interface belongs to.
* @li tune_api.h: the header file in which this interface is declared.
* @see None
* @since
*/
AoeStatus AoeOfflineTuning(const std::map<std::string, std::string> &option, std::string &msg);

/**
* @ingroup aoe
* @par Description: tuning initialization
*
* @attention None
* @param session [IN] GE session
* @param option [IN] option set, containing tuning options and GE options
* @retval #AOE_SUCCESS execution succeeded
* @retval #AOE_FAILURE execution failed
* @par Dependencies:
* @li tune_api.cpp: the development package this interface belongs to.
* @li tune_api.h: the header file in which this interface is declared.
* @see None
* @since
*/
extern "C" AoeStatus AoeOnlineInitialize(ge::Session *session, const std::map<std::string, std::string> &option);

/**
* @ingroup aoe
* @par Description: tuning finalization (de-initialization)
*
* @attention None
* @param None
* @retval #AOE_SUCCESS execution succeeded
* @retval #AOE_FAILURE execution failed
* @par Dependencies:
* @li tune_api.cpp: the development package this interface belongs to.
* @li tune_api.h: the header file in which this interface is declared.
* @see None
* @since
*/
extern "C" AoeStatus AoeOnlineFinalize();

/**
* @ingroup aoe
* @par Description: tuning processing
*
* @attention None
* @param tuningGraph [IN] graph to tune
* @param dependGraph [IN] graphs the tuning graph depends on
* @param session [IN] GE session
* @param option [IN] option set, containing tuning options and GE options
* @retval #AOE_SUCCESS execution succeeded
* @retval #AOE_FAILURE execution failed
* @par Dependencies:
* @li tune_api.cpp: the development package this interface belongs to.
* @li tune_api.h: the header file in which this interface is declared.
* @see None
* @since
*
* NOTE(review): this declaration uses std::vector, but this header only includes
* <map>, <string>, "ge/ge_api.h" and "aoe_types.h"; presumably <vector> arrives
* transitively through one of those - confirm, or include <vector> directly.
*/
extern "C" AoeStatus AoeOnlineTuning(ge::Graph &tuningGraph, std::vector<ge::Graph> &dependGraph,
ge::Session *session, const std::map<std::string, std::string> &option);
#endif

Loading…
Cancel
Save