Browse Source

Merge branch 'master' of gitee.com:mindspore/graphengine

pull/1907/head
zhaoxinxin 4 years ago
parent
commit
bb44fe1209
100 changed files with 2103 additions and 759 deletions
  1. +1
    -2
      .clang-format
  2. +2
    -0
      ge/common/CMakeLists.txt
  3. +1
    -1
      ge/common/dump/dump_properties.cc
  4. +1
    -0
      ge/common/dump/exception_dumper.cc
  5. +4
    -0
      ge/executor/CMakeLists.txt
  6. +1
    -1
      ge/ge_runtime/task/label_goto_task.cc
  7. +1
    -2
      ge/graph/build/model_builder.cc
  8. +20
    -24
      ge/graph/build/task_generator.cc
  9. +1
    -1
      ge/graph/build/task_generator.h
  10. +2
    -0
      ge/graph/load/model_manager/task_info/kernel_task_info.cc
  11. +0
    -49
      ge/graph/manager/graph_var_manager.cc
  12. +0
    -15
      ge/graph/manager/graph_var_manager.h
  13. +0
    -66
      ge/graph/manager/trans_var_data_utils.cc
  14. +0
    -11
      ge/graph/manager/trans_var_data_utils.h
  15. +138
    -91
      ge/graph/passes/base_pass.cc
  16. +1
    -1
      ge/graph/passes/folding_pass.cc
  17. +5
    -2
      ge/graph/passes/infer_base_pass.cc
  18. +1
    -1
      ge/graph/passes/infer_base_pass.h
  19. +26
    -6
      ge/graph/passes/infer_value_range_pass.cc
  20. +1
    -1
      ge/graph/passes/infer_value_range_pass.h
  21. +37
    -15
      ge/graph/passes/infershape_pass.cc
  22. +2
    -1
      ge/graph/passes/infershape_pass.h
  23. +2
    -3
      ge/graph/passes/merge_pass.cc
  24. +1
    -2
      ge/graph/passes/switch_dead_branch_elimination.cc
  25. +1
    -0
      ge/graph/preprocess/insert_op/util_insert_aipp_op.cc
  26. +1
    -1
      ge/graph/preprocess/multi_batch_copy_graph.cc
  27. +5
    -3
      ge/hybrid/executor/hybrid_model_async_executor.cc
  28. +0
    -4
      ge/hybrid/executor/hybrid_model_executor.cc
  29. +0
    -1
      ge/hybrid/executor/hybrid_model_pipeline_executor.cc
  30. +9
    -2
      ge/hybrid/executor/worker/task_compile_engine.cc
  31. +1
    -0
      ge/hybrid/model/hybrid_model_builder.cc
  32. +23
    -2
      ge/hybrid/model/node_item.cc
  33. +7
    -2
      ge/ir_build/option_utils.cc
  34. +4
    -3
      ge/offline/main.cc
  35. +68
    -85
      ge/single_op/single_op_model.cc
  36. +5
    -2
      ge/single_op/single_op_model.h
  37. +5
    -1
      ge/single_op/task/op_task.h
  38. +24
    -1
      ge/single_op/task/tbe_task_builder.cc
  39. +1
    -0
      ge/single_op/task/tbe_task_builder.h
  40. +1
    -1
      metadef
  41. +1
    -1
      parser
  42. +15
    -0
      scripts/env/Dockerfile
  43. +2
    -2
      scripts/env/ge_env.sh
  44. +1
    -0
      tests/depends/cce/CMakeLists.txt
  45. +0
    -13
      tests/framework/CMakeLists.txt
  46. +19
    -3
      tests/framework/easy_graph/include/easy_graph/builder/graph_dsl.h
  47. +6
    -2
      tests/framework/easy_graph/src/layout/graph_layout.cc
  48. +37
    -0
      tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/assert_error.h
  49. +32
    -0
      tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/check_utils.h
  50. +32
    -17
      tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/filter_scope_guard.h
  51. +59
    -0
      tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/graph_assert.h
  52. +2
    -4
      tests/framework/ge_graph_dsl/include/ge_graph_dsl/op_desc/op_desc_cfg.h
  53. +26
    -0
      tests/framework/ge_graph_dsl/src/assert/assert_error.cc
  54. +34
    -0
      tests/framework/ge_graph_dsl/src/assert/check_utils.cc
  55. +31
    -0
      tests/framework/ge_graph_dsl/src/assert/filter_scope_guard.cc
  56. +33
    -0
      tests/framework/ge_graph_dsl/src/assert/ge_dump_filter.h
  57. +79
    -0
      tests/framework/ge_graph_dsl/src/assert/ge_graph_check_dumper.cc
  58. +49
    -0
      tests/framework/ge_graph_dsl/src/assert/ge_graph_check_dumper.h
  59. +32
    -0
      tests/framework/ge_graph_dsl/src/assert/ge_graph_checker.h
  60. +28
    -0
      tests/framework/ge_graph_dsl/src/assert/ge_graph_default_checker.cc
  61. +41
    -0
      tests/framework/ge_graph_dsl/src/assert/ge_graph_default_checker.h
  62. +0
    -0
      tests/framework/ge_graph_dsl/src/op_desc/op_desc_cfg_box.cc
  63. +12
    -5
      tests/framework/ge_graph_dsl/src/op_desc/op_desc_cfg_repo.cc
  64. +1
    -3
      tests/framework/ge_graph_dsl/src/op_desc/op_desc_ptr_box.cc
  65. +3
    -9
      tests/framework/ge_graph_dsl/src/vistor/ge_graph_visitor.cc
  66. +0
    -0
      tests/framework/ge_graph_dsl/src/vistor/ge_subgraph_vistor.cc
  67. +0
    -0
      tests/framework/ge_graph_dsl/src/vistor/graph_dsl.cc
  68. +1
    -1
      tests/framework/ge_graph_dsl/tests/CMakeLists.txt
  69. +129
    -0
      tests/framework/ge_graph_dsl/tests/check_graph_test.cc
  70. +16
    -28
      tests/framework/ge_graph_dsl/tests/graph_dsl_test.cc
  71. +6
    -0
      tests/framework/ge_graph_dsl/tests/stub/optype_stub.cc
  72. +25
    -22
      tests/framework/ge_graph_dsl/tests/test_main.cc
  73. +0
    -48
      tests/framework/utils/builder/graph_builder_utils.cc
  74. +0
    -55
      tests/framework/utils/builder/graph_builder_utils.h
  75. +1
    -1
      tests/st/testcase/CMakeLists.txt
  76. +49
    -78
      tests/st/testcase/test_framework_dummy.cc
  77. +4
    -16
      tests/st/testcase/test_ge_opt_info.cc
  78. +2
    -2
      tests/st/testcase/test_main.cc
  79. +1
    -0
      tests/ut/common/graph/CMakeLists.txt
  80. +1
    -0
      tests/ut/ge/CMakeLists.txt
  81. +3
    -1
      tests/ut/ge/graph/build/task_generator_unittest.cc
  82. +1
    -1
      tests/ut/ge/graph/passes/addn_pass_unittest.cc
  83. +462
    -11
      tests/ut/ge/graph/passes/base_pass_unittest.cc
  84. +45
    -0
      tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc
  85. +104
    -6
      tests/ut/ge/graph/passes/infershape_pass_unittest.cc
  86. +28
    -0
      tests/ut/ge/hybrid/executor/hybrid_model_async_executor_unittest.cc
  87. +12
    -0
      tests/ut/ge/hybrid/executor/worker/execution_engine_unittest.cc
  88. +1
    -0
      tests/ut/ge/hybrid/ge_hybrid_unittest.cc
  89. +1
    -0
      tests/ut/ge/hybrid/node_executor/node_executor_unittest.cc
  90. +20
    -1
      tests/ut/ge/single_op/single_op_model_unittest.cc
  91. +1
    -0
      tests/ut/ge/single_op/single_op_task_unittest.cc
  92. +7
    -0
      third_party/fwkacllib/inc/external/runtime/rt_error_codes.h
  93. +4
    -4
      third_party/fwkacllib/inc/runtime/base.h
  94. +43
    -0
      third_party/fwkacllib/inc/runtime/config.h
  95. +5
    -0
      third_party/fwkacllib/inc/runtime/dev.h
  96. +35
    -2
      third_party/fwkacllib/inc/runtime/event.h
  97. +66
    -17
      third_party/fwkacllib/inc/runtime/kernel.h
  98. +11
    -0
      third_party/fwkacllib/inc/runtime/mem.h
  99. +11
    -2
      third_party/fwkacllib/inc/runtime/rt_model.h
  100. +30
    -1
      third_party/fwkacllib/inc/toolchain/prof_callback.h

+ 1
- 2
.clang-format View File

@@ -52,7 +52,6 @@ ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DerivePointerAlignment: true
DisableFormat: false
ExperimentalAutoDetectBinPacking: false
FixNamespaceComments: true
@@ -94,7 +93,7 @@ PenaltyBreakString: 1000
PenaltyBreakTemplateDeclaration: 10
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 200
PointerAlignment: Left
PointerAlignment: Right
RawStringFormats:
- Language: Cpp
Delimiters:


+ 2
- 0
ge/common/CMakeLists.txt View File

@@ -95,6 +95,7 @@ target_link_libraries(ge_common PRIVATE
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:slog_headers>>
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:msprof_headers>>
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:mmpa_headers>>
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:runtime_headers>>
static_mmpa
-Wl,--no-as-needed
graph
@@ -155,6 +156,7 @@ target_link_libraries(ge_common_static PRIVATE
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:slog_headers>>
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:msprof_headers>>
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:mmpa_headers>>
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:runtime_headers>>
ascend_protobuf_static
json
c_sec


+ 1
- 1
ge/common/dump/dump_properties.cc View File

@@ -163,7 +163,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status DumpProperties::CheckDum
GELOGE(PARAM_INVALID, "[Check][Param] the dumpPath:%s, is invalid.", input.c_str());
return PARAM_INVALID;
}
if (mmAccess2(trusted_path, R_OK | W_OK) != EN_OK) {
if (mmAccess2(trusted_path, M_R_OK | M_W_OK) != EN_OK) {
REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}),
std::vector<std::string>({
"ge.exec.dumpPath",


+ 1
- 0
ge/common/dump/exception_dumper.cc View File

@@ -161,6 +161,7 @@ Status ExceptionDumper::DumpExceptionInfo(const std::vector<rtExceptionInfo> &ex

uint64_t proto_size = dump_data.ByteSizeLong();
std::unique_ptr<char[]> proto_msg(new (std::nothrow) char[proto_size]);
GE_CHECK_NOTNULL(proto_msg);
bool ret = dump_data.SerializeToArray(proto_msg.get(), proto_size);
if (!ret || proto_size == 0) {
REPORT_INNER_ERROR("E19999", "Serialize proto to string fail");


+ 4
- 0
ge/executor/CMakeLists.txt View File

@@ -186,6 +186,8 @@ target_include_directories(ge_executor SYSTEM PRIVATE
${CMAKE_BINARY_DIR}/proto/graphengine_protos
#### yellow zone ####
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:${GE_DEPEND_DIR}/inc>
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<TARGET_PROPERTY:runtime_headers,INTERFACE_INCLUDE_DIRECTORIES>>
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<TARGET_PROPERTY:cce_headers,INTERFACE_INCLUDE_DIRECTORIES>>
#### blue zone ####
$<$<BOOL:${ENABLE_OPEN_SRC}>:${GE_CODE_DIR}/third_party/fwkacllib/inc>
$<$<BOOL:${ENABLE_OPEN_SRC}>:${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain>
@@ -251,6 +253,8 @@ target_link_libraries(ge_executor_shared PRIVATE
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:slog_headers>>
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:msprof_headers>>
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:mmpa_headers>>
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:runtime_headers>>
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:cce_headers>>
-Wl,--no-as-needed
ge_common
runtime


+ 1
- 1
ge/ge_runtime/task/label_goto_task.cc View File

@@ -72,7 +72,7 @@ bool LabelGotoTask::Distribute() {
return false;
}

rt_ret = rtLabelListCpy((void**)label_list.data(), label_list.size(), label_info_, label_info_size);
rt_ret = rtLabelListCpy(reinterpret_cast<void**>(label_list.data()), label_list.size(), label_info_, label_info_size);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret);
return false;


+ 1
- 2
ge/graph/build/model_builder.cc View File

@@ -32,7 +32,6 @@
#include "graph/ge_attr_value.h"
#include "graph/ge_context.h"
#include "external/graph/ge_error_codes.h"
#include "graph/manager/graph_mem_allocator.h"
#include "graph/manager/graph_var_manager.h"
#include "graph/optimize/common/params.h"
#include "external/graph/types.h"
@@ -707,7 +706,7 @@ Status ModelBuilder::SaveDataToModel(ge::Model &model, ge::GeModel &ge_model) {
if (!kernel_name.empty() && (kernel_buffer.GetSize() > 0)) {
GE_CHECK_NOTNULL(kernel_buffer.GetData());
std::vector<char> data(kernel_buffer.GetData(), kernel_buffer.GetData() + kernel_buffer.GetSize());
tbe_kernel = std::make_shared<OpKernelBin>(kernel_name, std::move(data));
tbe_kernel = MakeShared<OpKernelBin>(kernel_name, std::move(data));
GE_CHECK_NOTNULL(tbe_kernel);
GELOGI("Node [%s][%s] start recovery extra attr %s from %s", node_op_desc->GetName().c_str(),
node_op_desc->GetType().c_str(), ge::OP_EXTATTR_NAME_TBE_KERNEL, ATTR_NAME_TBE_KERNEL_NAME.c_str());


+ 20
- 24
ge/graph/build/task_generator.cc View File

@@ -793,7 +793,6 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP
GELOGI("Start AutoFindBpOpIndex");
NodePtr bp_node = nullptr;
uint32_t current_idx = 0;
uint32_t netoutput_idx = 0;
for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) {
OpDescPtr op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
@@ -811,7 +810,6 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP
if (op_desc->GetName() == NODE_NAME_NET_OUTPUT) {
if (bp_node == nullptr) {
bp_node = node;
netoutput_idx = current_idx - 1;
}
}
if (graph->GetNeedIteration()) {
@@ -836,34 +834,30 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP
if (bp_node == nullptr) {
GELOGW("not find bp_node.");
return SUCCESS;
} else if (bp_node->GetName() == NODE_NAME_NET_OUTPUT) {
profiling_point.bp_index = netoutput_idx;
GELOGI("First bp name %s, idx %u", bp_node->GetName().c_str(), netoutput_idx);
} else {
profiling_point.bp_index = FindLastBpFromBpNode(graph, bp_node);
}

return SUCCESS;
return FindLastBpFromBpNode(graph, bp_node, profiling_point.bp_index);
}

uint32_t TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node) const {
uint32_t last_bp = 0;
Status TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &target_node,
uint32_t &bp_index) const {
bp_index = 0;
auto target_desc = target_node->GetOpDesc();
GE_CHECK_NOTNULL(target_desc);
OpDescPtr bp_op_desc = nullptr;
for (auto &in_anchor : bp_node->GetAllInDataAnchors()) {
auto out_anchor = in_anchor->GetPeerOutAnchor();
if (out_anchor == nullptr || out_anchor->GetOwnerNode() == nullptr) {
continue;
}
auto out_node_desc = out_anchor->GetOwnerNode()->GetOpDesc();
GE_CHECK_NOTNULL(out_node_desc);
if (bp_op_desc == nullptr || ((out_node_desc->GetId()) > (bp_op_desc->GetId()))) {
bp_op_desc = out_node_desc;
for (auto &in_node : target_node->GetInAllNodes()) {
GE_CHECK_NOTNULL(in_node);
auto in_node_desc = in_node->GetOpDesc();
GE_CHECK_NOTNULL(in_node_desc);
if ((bp_op_desc == nullptr || (in_node_desc->GetId() > bp_op_desc->GetId())) &&
(in_node_desc->GetStreamId() == target_desc->GetStreamId())){
bp_op_desc = in_node_desc;
}
GELOGI("bp_op_desc is %s, id is %ld", bp_op_desc->GetName().c_str(), bp_op_desc->GetId());
}

if (bp_op_desc == nullptr) {
return last_bp;
GELOGI("Did not find bp node.");
return SUCCESS;
}
uint32_t current_idx = 0;
for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) {
@@ -871,12 +865,14 @@ uint32_t TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const
GE_CHECK_NOTNULL(op_desc);
current_idx++;
if (op_desc->GetName() == bp_op_desc->GetName()) {
last_bp = current_idx;
GELOGI("First bp name %s, idx %u", op_desc->GetName().c_str(), last_bp);
bp_index = current_idx;
GELOGI("Find bp name %s, idx %u", op_desc->GetName().c_str(), bp_index);
break;
}
}
return last_bp;
GELOGI("Last bp node[%s], type[%s], index[%u], stream id[%ld]", bp_op_desc->GetName().c_str(),
bp_op_desc->GetType().c_str(), bp_index, bp_op_desc->GetStreamId());
return SUCCESS;
}

Status TaskGenerator::FindFpOfEnv(const ComputeGraphPtr &graph, const std::string &fp_point_str,


+ 1
- 1
ge/graph/build/task_generator.h View File

@@ -116,7 +116,7 @@ class TaskGenerator {
Status AutoFindFpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point) const;
Status AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point,
vector<uint32_t> &all_reduce_nodes) const;
uint32_t FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node) const;
Status FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node, uint32_t &bp_index) const;

Status FindFpOfEnv(const ComputeGraphPtr &graph, const std::string &fp_point_str,
ProfilingPoint &profiling_point) const;


+ 2
- 0
ge/graph/load/model_manager/task_info/kernel_task_info.cc View File

@@ -645,6 +645,7 @@ Status KernelTaskInfo::InitTVMTask(uint16_t offset, const domi::KernelDef &kerne
GE_CHECK_NOTNULL(op_desc);

args_addr = std::unique_ptr<uint8_t[]>(new (std::nothrow) uint8_t[args_size_]);
GE_CHECK_NOTNULL(args_addr);
errno_t sec_ret = memcpy_s(args_addr.get(), args_size_, kernel_def.args().data(), args_size_);
if (sec_ret != EOK) {
REPORT_CALL_ERROR("E19999", "Call memcpy_s fail, size:%u, ret:0x%X", args_size_, sec_ret);
@@ -1000,6 +1001,7 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k

// copy args to new host memory
args_addr = std::unique_ptr<uint8_t[]>(new (std::nothrow) uint8_t[args_size_]);
GE_CHECK_NOTNULL(args_addr);
GE_PRINT_DYNAMIC_MEMORY(new, "cce task physical memory.", sizeof(uint8_t) * args_size_)
errno_t sec_ret = memcpy_s(args_addr.get(), args_size_, kernel_def.args().data(), args_size_);
if (sec_ret != EOK) {


+ 0
- 49
ge/graph/manager/graph_var_manager.cc View File

@@ -194,35 +194,6 @@ ge::Status VarResource::GetBroadCastInfo(uint32_t graph_id, const string &var_na
return SUCCESS;
}

ge::Status VarResource::SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name,
const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) {
GE_CHECK_NOTNULL(base_ptr);
GELOGI("SyncVarData2BroadCast graph_id: %u, var_name: %s.", graph_id, var_name.c_str());

VarBroadCastInfo var_broadcast_info = var_broad_cast_info_[graph_id][var_name];
uint8_t *dst_addr = base_ptr + var_broadcast_info.input_offset;

return ge::TransVarDataUtils::SyncVarData2BroadCast(var_name, var_tensor_desc, dst_addr,
var_broadcast_info.input_size, session_id_);
}

ge::Status VarResource::SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name,
const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) {
GELOGI("SyncBroadCastData2Var var_name: %s", var_name.c_str());

VarBroadCastInfo var_broadcast_info = var_broad_cast_info_[graph_id][var_name];
// subgraph base_ptr could be nullptr, task it as base 0
uint8_t *dst_addr = base_ptr + var_broadcast_info.output_offset;

return ge::TransVarDataUtils::SyncBroadCastData2Var(dst_addr, var_broadcast_info.output_size, var_name,
var_tensor_desc, session_id_);
}

ge::Status VarResource::SyncVarData(uint32_t graph_id, const std::string &var_name,
const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) {
return SyncVarData2BroadCast(graph_id, var_name, var_tensor_desc, base_ptr);
}

bool VarResource::IsVarAddr(const int64_t &offset) { return var_offset_map_.count(offset) > 0; }

rtMemType_t VarResource::GetVarMemType(const int64_t &offset) {
@@ -638,16 +609,6 @@ bool VarManager::IsVarExist(const std::string &var_name) {
return var_resource_->IsVarExist(var_name);
}

ge::Status VarManager::SyncVarData(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc,
uint8_t *base_ptr) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
if (var_resource_ == nullptr) {
GELOGW("VarManager has not been init.");
return ge::INTERNAL_ERROR;
}
return var_resource_->SyncVarData(graph_id, var_name, var_tensor_desc, base_ptr);
}

ge::Status VarManager::GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
GELOGI("VarManager::GetCurVarDesc var_name = %s.", var_name.c_str());
@@ -701,16 +662,6 @@ ge::Status VarManager::RenewCurVarDesc(const std::string &var_name, ge::OpDescPt
return var_resource_->RenewCurVarDesc(var_name, std::move(op_desc));
}

ge::Status VarManager::SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name,
const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
if (var_resource_ == nullptr) {
GELOGW("VarManager has not been init.");
return ge::INTERNAL_ERROR;
}
return var_resource_->SyncBroadCastData2Var(graph_id, var_name, var_tensor_desc, base_ptr);
}

bool VarManager::IsVarAddr(const int64_t &offset) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
if (var_resource_ == nullptr) {


+ 0
- 15
ge/graph/manager/graph_var_manager.h View File

@@ -118,15 +118,6 @@ class VarResource {

ge::Status GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info);

ge::Status SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name,
const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr);

ge::Status SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name,
const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr);

ge::Status SyncVarData(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc,
uint8_t *base_ptr);

Status SetTransRoad(const std::string &var_name, const VarTransRoad &trans_road) {
if (var_to_trans_road_.find(var_name) != var_to_trans_road_.end()) {
GELOGW("Var name: %s has already set.", var_name.c_str());
@@ -234,16 +225,10 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager {

ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr);

ge::Status SyncVarData(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc,
uint8_t *base_ptr);

ge::Status SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info);

ge::Status GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info);

ge::Status SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc,
uint8_t *base_ptr);

ge::Status GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc);

ge::Status RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc);


+ 0
- 66
ge/graph/manager/trans_var_data_utils.cc View File

@@ -415,72 +415,6 @@ Status CopyTensorFromSrcVarNode(const NodePtr &var_src,
return SUCCESS;
}
} // namespace
Status TransVarDataUtils::SyncVarData2BroadCast(const string &var_name, const ge::GeTensorDesc &src_tensor_desc,
uint8_t *dst_addr, int64_t dst_addr_size, uint64_t session_id) {
GE_CHK_BOOL_RET_STATUS(dst_addr != nullptr, FAILED, "[Check][Param] dst addr is nullptr.");
uint8_t *src_host_addr = nullptr;
int64_t src_addr_size = 0;
GE_MAKE_GUARD_RTMEM(src_host_addr);
GE_CHK_STATUS_RET(SyncTensorToHost(var_name, src_tensor_desc, &src_host_addr, src_addr_size, session_id));

GELOGI("src_addr_size: %ld, dst_addr_size: %ld", src_addr_size, dst_addr_size);
GE_CHK_BOOL_RET_STATUS(src_addr_size == dst_addr_size, FAILED,
"[Check][Param] src_addr_size:%ld not equal to dst_addr_size:%ld",
src_addr_size, dst_addr_size);

GE_CHK_RT_RET(rtMemcpy(dst_addr, dst_addr_size, src_host_addr, src_addr_size, RT_MEMCPY_HOST_TO_DEVICE));
return SUCCESS;
}

Status TransVarDataUtils::SyncBroadCastData2Var(uint8_t *src_addr, int64_t src_addr_size, const string &var_name,
const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id) {
GE_CHK_BOOL_RET_STATUS(src_addr != nullptr, FAILED, "[Check][Param] src addr is nullptr. ");
uint8_t *host_addr = nullptr;
GE_MAKE_GUARD_RTMEM(host_addr);
GE_CHK_RT_RET(rtMallocHost(reinterpret_cast<void **>(&host_addr), src_addr_size));
GE_CHK_RT_RET(rtMemcpy(host_addr, src_addr_size, src_addr, src_addr_size, RT_MEMCPY_DEVICE_TO_HOST));

GE_CHK_STATUS_RET(
SyncTensorToDevice(var_name, reinterpret_cast<uint8_t *>(host_addr), src_addr_size, dst_tensor_desc, session_id));

return SUCCESS;
}

Status TransVarDataUtils::SyncTensorToHost(const string &var_name, const ge::GeTensorDesc &src_tensor_desc,
uint8_t **host_addr, int64_t &src_tensor_size, uint64_t session_id) {
GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(src_tensor_desc, src_tensor_size), "[Get][Size] from TensorDesc failed");

uint8_t *src_addr = nullptr;
GE_CHK_STATUS_RET(VarManager::Instance(session_id)->GetVarAddr(var_name, src_tensor_desc, &src_addr));
uint8_t *mem_addr =
src_addr -
static_cast<int64_t>(static_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemLogicBase())) +
static_cast<int64_t>(
reinterpret_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM)));
GE_CHK_RT_RET(rtMallocHost(reinterpret_cast<void **>(host_addr), src_tensor_size));

GE_CHK_RT_RET(rtMemcpy(*host_addr, src_tensor_size, mem_addr, src_tensor_size, RT_MEMCPY_DEVICE_TO_HOST));

GELOGI("SyncTensorToHost var_name %s, src_tensor_size %ld", var_name.c_str(), src_tensor_size);
return SUCCESS;
}

Status TransVarDataUtils::SyncTensorToDevice(const string &var_name, const uint8_t *host_addr, uint32_t addr_size,
const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id) {
uint8_t *dst_addr = nullptr;
GE_CHK_STATUS_RET(VarManager::Instance(session_id)->GetVarAddr(var_name, dst_tensor_desc, &dst_addr));
uint8_t *mem_addr =
dst_addr -
static_cast<int64_t>(static_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemLogicBase())) +
static_cast<int64_t>(
reinterpret_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM)));
GE_CHK_RT_RET(rtMemcpy(mem_addr, addr_size, host_addr, addr_size, RT_MEMCPY_HOST_TO_DEVICE));

GELOGI("SyncTensorToDevice var_name %s, addr_size %u", var_name.c_str(), addr_size);

return SUCCESS;
}

Status TransVarDataUtils::TransAllVarData(const vector<NodePtr> &variable_nodes,
uint64_t session_id,
rtContext_t context,


+ 0
- 11
ge/graph/manager/trans_var_data_utils.h View File

@@ -29,11 +29,6 @@
namespace ge {
class TransVarDataUtils {
public:
static ge::Status SyncVarData2BroadCast(const string &var_name, const ge::GeTensorDesc &src_tensor_desc,
uint8_t *dst_addr, int64_t dst_addr_size, uint64_t session_id_);
static ge::Status SyncBroadCastData2Var(uint8_t *src_addr, int64_t src_addr_size, const string &var_name,
const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id_);

static ge::Status TransAllVarData(const std::vector<NodePtr> &variable_nodes,
uint64_t session_id,
rtContext_t context,
@@ -41,12 +36,6 @@ class TransVarDataUtils {
uint32_t thread_num = 16);

static ge::Status CopyVarData(const ComputeGraphPtr &compute_graph, uint64_t session_id, uint32_t device_id);

private:
static ge::Status SyncTensorToHost(const string &var_name, const ge::GeTensorDesc &src_tensor_desc,
uint8_t **host_addr, int64_t &addr_size, uint64_t session_id_);
static ge::Status SyncTensorToDevice(const string &var_name, const uint8_t *host_addr, uint32_t addr_size,
const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id_);
};
} // namespace ge



+ 138
- 91
ge/graph/passes/base_pass.cc View File

@@ -46,50 +46,116 @@ void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph,

bool AllNodesIn(const Node::Vistor<NodePtr> &nodes, const std::unordered_set<NodePtr> &nodes_set) {
return !std::any_of(nodes.begin(), nodes.end(), [&](const NodePtr &n) {
return nodes_set.count(n) == 0;
});
}

bool AnyNodesIn(const Node::Vistor<NodePtr> &nodes, const std::unordered_set<NodePtr> &nodes_set) {
return std::any_of(nodes.begin(), nodes.end(), [&](const NodePtr &n) {
return nodes_set.count(n) > 0;
});
}

void AddNextIterNodes(const NodePtr &cur_node, GEPass::GraphLevelState &g_state) {
const auto &nodes_suspend = g_state.nodes_suspend;
bool IsNodeReadyToQueue(const NodePtr &node, GEPass::GraphLevelState &g_state) {
if (node == nullptr) {
GELOGW("node is null");
return false;
}

if (g_state.nodes_deleted.count(node) > 0) {
GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str());
return false;
}

if (g_state.nodes_last.count(node) != 0) {
return false;
}

if (!node->IsAllInNodesSeen(g_state.nodes_seen)) {
return false;
}

// 因为在PassNode之前,会首先将node的输出节点添加queue,因此若在pass node时,suspend了node的输出节点,后续逻辑与上面相同
// TODO 需要注意的是,这里的保证是一次”尽力而为“,若pass node时,将node之前的节点`A`添加到了suspend,
// 那么`A`节点的后继和间接后继节点的pass不会受到suspend的影响
// 理论上来说,如果在pass node之前,首先收集node的输出节点,在pass后,将输出节点做suspend、delete的去除,然后加queue,
// 这样处理就不需要在这里做额外的确认了
if (g_state.nodes_suspend.count(node) > 0) {
GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.",
node->GetName().c_str());
return false;
}
if (AnyNodesIn(node->GetInAllNodes(), g_state.nodes_suspend)) {
GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.",
node->GetName().c_str());
return false;
}
return true;
}

void CollectOutNodesBeforePass(const NodePtr &node, std::unordered_set<NodePtr> &out_nodes_before_pass) {
for (const auto &out_node : node->GetOutNodes()) {
out_nodes_before_pass.insert(out_node);
}
}

void AddNextIterNodes(const NodePtr &cur_node, std::unordered_set<NodePtr> &out_nodes_before_pass,
GEPass::GraphLevelState &g_state) {
for (auto &node : cur_node->GetOutNodes()) {
if (node == nullptr) {
continue;
}
if (g_state.nodes_last.count(node) != 0) {
continue;
if (out_nodes_before_pass.erase(node) == 0) {
// after pass node , new output node come up
GELOGD("New output nodes %s come up after pass %s.", node->GetName().c_str(), cur_node->GetName().c_str());
}

if (IsNodeReadyToQueue(node, g_state)) {
g_state.AddNodeToQueueIfNotSeen(node);
}
if (nodes_suspend.count(node) > 0) {
GELOGD("The node %s has suspend by pass, skip it.", node->GetName().c_str());
}
// A-->B-->C
// \
// D--->E
// If B has been delete after pass, two case need to consider
// 1. A & C & E has been repass by B. good choice
// 2. A & C & E not added to repass, C will not pass because no one trigger it.
// while E will pass because D will trigger it.
// So here we need add node which has no input_node to queue.
for (const auto &node : out_nodes_before_pass) {
if (!node->GetInAllNodes().empty()) {
GELOGD("Node %s used to be output of node %s, but after pass it doesnt. "
"It may triggered by other node, so no need add to queue now.");
continue;
}

if (node->IsAllInNodesSeen(g_state.nodes_seen) && AllNodesIn(node->GetInAllNodes(), nodes_suspend)) {
if (IsNodeReadyToQueue(node, g_state)) {
// unlink edge may happen, add these node to queue otherwise they can not pass
GELOGI("Node %s may lost from cur node, add to queue if not seen.",
node->GetName().c_str(), cur_node->GetName().c_str());
g_state.AddNodeToQueueIfNotSeen(node);
}
}
}

void AddImmediateRepassNodesToQueue(NodePtr &cur_node, const std::pair<std::string, BaseNodePass *> &name_to_pass,
const std::unordered_set<NodePtr> &nodes_im_re_pass,
void AddImmediateRepassNodesToQueue(NodePtr &cur_node,
const std::unordered_map<NodePtr, std::string> re_pass_imm_nodes_to_pass_names,
GEPass::GraphLevelState &g_state) {
for (const auto &node : nodes_im_re_pass) {
if (node == nullptr) {
GELOGW("Found null immediately re-pass node when executing pass %s on node %s type %s", name_to_pass.first.c_str(),
for (const auto &node_2_pass_names : re_pass_imm_nodes_to_pass_names) {
auto repass_imm_node = node_2_pass_names.first;
if (repass_imm_node == nullptr) {
GELOGW("Found null immediately re-pass node when executing pass %s on node %s type %s",
node_2_pass_names.second.c_str(),
cur_node->GetName().c_str(), cur_node->GetType().c_str());
continue;
}
if (g_state.nodes_passed.count(node) > 0) {
g_state.AddNodeToQueueFront(node);
continue;
}
// exp: constant folding add new const need repass immediate
if (AllNodesIn(node->GetInAllNodes(), g_state.nodes_passed)) {
g_state.AddNodeToQueueFront(node);
if (g_state.nodes_passed.count(repass_imm_node) > 0) {
GELOGD("The node %s specified by pass %s has been passed, it will repass immediately",
repass_imm_node->GetName().c_str(), node_2_pass_names.second.c_str());
g_state.AddNodeToQueueFront(repass_imm_node);
continue;
}
GELOGW("The node %s specified by pass %s has un-passed in_nodes, it will not repass immediately",
node->GetName().c_str(), name_to_pass.first.c_str());
repass_imm_node->GetName().c_str(), node_2_pass_names.second.c_str());
}
}

@@ -103,23 +169,18 @@ void AddLastNodesToQueue(GEPass::GraphLevelState &g_state) {
g_state.nodes_last.clear();
}

void SuspendAndResume(const std::string &pass_name,
const std::unordered_set<NodePtr> &nodes_suspend,
const std::unordered_set<NodePtr> &nodes_resume,
GEPass::GraphLevelState &g_state) {
// TODO 当前没有记录NodePass中suspend和resume的顺序,因此无法辨别NodePass中是先做Suspend还是Resume。
// 因此此处的简单处理是如果在NodePass的过程中,触发了suspend/resume,那么框架以resume为准
// 更好的处理方式是,在NodePass做suspend/resume时,做顺序的记录,在此函数中按序做回放
for (const auto &node : nodes_suspend) {
GELOGD("The iteration suspend of node %s has been set by pass %s", node->GetName().c_str(), pass_name.c_str());
g_state.nodes_suspend.insert(node);
}

for (const auto &node : nodes_resume) {
void AddResumeNodesToQueue(const std::unordered_map<NodePtr, std::string> resume_nodes_to_pass_names,
GEPass::GraphLevelState &g_state) {
// Currently we dont keep the order of suspend nodes and resume nodes, so its hard to know
// which one comes first. Simple way : if a node both have suspend & resume state, we will resume it.
// Better way: keep the order when suspend/resume a node, and in this func suspend/resume in order.
for (const auto &node_2_pass_names : resume_nodes_to_pass_names) {
auto node = node_2_pass_names.first;
if (g_state.nodes_suspend.erase(node) > 0) {
if (g_state.nodes_seen.count(node.get()) > 0 || node->IsAllInNodesSeen(g_state.nodes_seen)) {
g_state.nodes.push_back(node);
GELOGD("Node %s has been resumed by pass %s", node->GetName().c_str(), pass_name.c_str());
GELOGD("Node %s has been resumed by pass %s, add to queue.",
node->GetName().c_str(), node_2_pass_names.second.c_str());
}
}
}
@@ -154,36 +215,6 @@ void ClearOption(NamesToPass names_to_pass) {
name_to_pass.second->ClearOptions();
}
}

bool ShouldNodePassActually(const NodePtr &node, const GEPass::GraphLevelState &g_state) {
if (node == nullptr) {
GELOGW("node is null");
return false;
}
// 因为在PassNode之前,会首先将node的输出节点添加queue,因此若在pass node时,删除了node的输出节点,
// 那么会出现:已经删除的节点出现在queue中,并且被pop出来,因此这里做确认,如果node已经被删除过了,就跳过pass
if (g_state.nodes_deleted.count(node) > 0) {
GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str());
return false;
}

// 因为在PassNode之前,会首先将node的输出节点添加queue,因此若在pass node时,suspend了node的输出节点,后续逻辑与上面相同
// TODO 需要注意的是,这里的保证是一次”尽力而为“,若pass node时,将node之前的节点`A`添加到了suspend,
// 那么`A`节点的后继和间接后继节点的pass不会受到suspend的影响
// 理论上来说,如果在pass node之前,首先收集node的输出节点,在pass后,将输出节点做suspend、delete的去除,然后加queue,
// 这样处理就不需要在这里做额外的确认了
if (g_state.nodes_suspend.count(node) > 0) {
GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.",
node->GetName().c_str());
return false;
}
if (!AllNodesIn(node->GetInAllNodes(), g_state.nodes_suspend)) {
GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.",
node->GetName().c_str());
return false;
}
return true;
}
} // namespace

Status BaseNodePass::IsolateAndDeleteNode(NodePtr &node, const std::vector<int> &io_map,
@@ -256,18 +287,20 @@ void NotifyPassGraphStart(const ComputeGraphPtr &graph, const NamesToPass &names
}

Status GEPass::HandleLeakedSuspendNodes(const NamesToPass &names_to_passes, GraphLevelState &g_state) {
std::unordered_map<NodePtr, std::string> resume_nodes_to_pass_names;
for (auto &name_to_pass : names_to_passes) {
name_to_pass.second->init();
auto ret = name_to_pass.second->OnSuspendNodesLeaked();
if (ret != SUCCESS) {
// todo error
GELOGE(ret, "Internal Error happened when pass %s handle on suspend nodes leaked.",
name_to_pass.first.c_str());
return ret;
}
SuspendAndResume(name_to_pass.first,
name_to_pass.second->GetNodesSuspend(),
name_to_pass.second->GetNodesResume(),
g_state);
for (const auto &resume_node : name_to_pass.second->GetNodesResume()){
resume_nodes_to_pass_names[resume_node].append(name_to_pass.first + ",");
}
}
AddResumeNodesToQueue(resume_nodes_to_pass_names, g_state);
return SUCCESS;
}

@@ -283,11 +316,13 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) {
if (!g_state.nodes_suspend.empty()) {
auto ret = HandleLeakedSuspendNodes(names_to_passes, g_state);
if (ret != SUCCESS) {
// todo log
GELOGE(ret, "Failed to handle leaked suspend nodes, break base pass.");
return ret;
}
if (g_state.nodes.empty()) {
// todo 报错,因为suspend泄露场景,没有子类做进一步的resume,此处可能已经彻底泄露,需要报错
// There are suspend nodes leaked, but no pass resume it
GELOGE(INTERNAL_ERROR, "There are suspend nodes but no pass resume, which means"
"some nodes in this graph never pass.");
return INTERNAL_ERROR;
}
}
@@ -305,6 +340,7 @@ Status GEPass::RunPassesGraphRepass(const NamesToPass &names_to_passes, GraphLev
RepassLevelState rp_state;
do {
for (auto &node : rp_state.nodes_re_pass) {
GELOGD("Add node %s to queue for re-pass.", node->GetName().c_str());
g_state.AddNodeToQueue(node);
}
rp_state.nodes_re_pass.clear();
@@ -312,12 +348,14 @@ Status GEPass::RunPassesGraphRepass(const NamesToPass &names_to_passes, GraphLev
while (!g_state.nodes.empty()) {
auto node = g_state.PopFront();

(void)rp_state.nodes_re_pass.erase(node); // todo 回忆一下为什么
if (!ShouldNodePassActually(node, g_state)) {
continue;
if (g_state.nodes_deleted.count(node) > 0) {
GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str());
}
(void)rp_state.nodes_re_pass.erase(node);// todo why
g_state.nodes_seen.insert(node.get()); // todo 为什么这里seen
AddNextIterNodes(node, g_state);

std::unordered_set<NodePtr> out_nodes_before_pass;
CollectOutNodesBeforePass(node, out_nodes_before_pass);

auto ret = RunPassesNodeOnce(node, names_to_passes, g_state, rp_state);
if (ret != SUCCESS) {
@@ -325,6 +363,7 @@ Status GEPass::RunPassesGraphRepass(const NamesToPass &names_to_passes, GraphLev
node->GetType().c_str(), ret);
return ret;
}
AddNextIterNodes(node, out_nodes_before_pass, g_state);
}
AddLastNodesToQueue(g_state);
} while ((!rp_state.nodes_re_pass.empty() || !g_state.nodes.empty()) && ++g_state.re_pass_times < kMaxRePassTimes);
@@ -405,10 +444,11 @@ Status GEPass::RunPassesOnNode(NodePtr &node, const NamesToPass &names_to_passes
name_to_pass.second->init();
auto result = name_to_pass.second->Run(node);
if (result != SUCCESS) {
REPORT_CALL_ERROR("E19999", "process pass %s on node:%s failed, ret:%u",
name_to_pass.first.c_str(), node->GetName().c_str(), result);
GELOGE(INTERNAL_ERROR, "[Process][Pass] %s on node %s failed, result "
"%u, the passes will be terminated immediately.",
REPORT_CALL_ERROR("E19999", "process pass %s on node:%s failed, ret:%u", name_to_pass.first.c_str(),
node->GetName().c_str(), result);
GELOGE(INTERNAL_ERROR,
"[Process][Pass] %s on node %s failed, result "
"%u, the passes will be terminated immediately.",
name_to_pass.first.c_str(), node->GetName().c_str(), result);
return result;
}
@@ -421,23 +461,30 @@ Status GEPass::RunPassesOnNode(NodePtr &node, const NamesToPass &names_to_passes

g_state.nodes_passed.insert(node);

std::unordered_map<NodePtr, std::string> repass_imm_nodes_to_pass_names;
std::unordered_map<NodePtr, std::string> resume_nodes_to_pass_names;
// if multi pass add one node to repass immediately, here need to remove duplication
for (const auto &name_to_pass : names_to_passes) {
PushToRePassIfSeen(node, name_to_pass, g_state.nodes_seen,
name_to_pass.second->GetNodesNeedRePass(),
PushToRePassIfSeen(node, name_to_pass, g_state.nodes_seen, name_to_pass.second->GetNodesNeedRePass(),
rp_state.nodes_re_pass);
// collect imm_node && resume_node among these passes
for (const auto &imm_node : name_to_pass.second->GetNodesNeedRePassImmediately()) {
repass_imm_nodes_to_pass_names[imm_node].append(name_to_pass.first + ",");
}
for (const auto &resume_node : name_to_pass.second->GetNodesResume()) {
resume_nodes_to_pass_names[resume_node].append(name_to_pass.first + ",");
}

AddImmediateRepassNodesToQueue(node, name_to_pass,
name_to_pass.second->GetNodesNeedRePassImmediately(),
g_state);
SuspendAndResume(name_to_pass.first,
name_to_pass.second->GetNodesSuspend(),
name_to_pass.second->GetNodesResume(),
g_state);

for (const auto &suspend_node : name_to_pass.second->GetNodesSuspend()) {
GELOGD("The iteration suspend of node %s has been set by pass %s", suspend_node->GetName().c_str(),
name_to_pass.first.c_str());
g_state.nodes_suspend.insert(suspend_node);
}
const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted();
g_state.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end());
}

AddImmediateRepassNodesToQueue(node, repass_imm_nodes_to_pass_names, g_state);
AddResumeNodesToQueue(resume_nodes_to_pass_names, g_state);
return SUCCESS;
}
} // namespace ge

+ 1
- 1
ge/graph/passes/folding_pass.cc View File

@@ -363,7 +363,7 @@ Status FoldingPass::ConnectNodeToInAnchor(InDataAnchorPtr &in_anchor, NodePtr &n
in_anchor->GetIdx());
return INTERNAL_ERROR;
}
AddImmediateRePassNode(node);
AddRePassNodesWithInOut(node);
return SUCCESS;
}
} // namespace ge

+ 5
- 2
ge/graph/passes/infer_base_pass.cc View File

@@ -84,8 +84,11 @@ Status InferBasePass::Run(NodePtr &node) {

bool InferBasePass::NeedInfer(const NodePtr &node) const { return true; }
void InferBasePass::AddChangedNodesImmediateRepass(const std::set<NodePtr> &changed_nodes) {
// need passed_nodes set to solve the problem that multi-input operators do repass in advance.
// when there is passed_nodes set, wo should call AddImmediateRePassNode for all nodes in changed_nodes.
// need passed_nodes set to solve the problem that multi-input operators do repass in advance.
// when there is passed_nodes set, wo should call AddImmediateRePassNode for all nodes in changed_nodes.
for (const auto &node : changed_nodes) {
AddImmediateRePassNode(node);
}
}

graphStatus InferBasePass::InferAndUpdate(NodePtr &node, bool before_subgraph, std::set<NodePtr> &changed_nodes) {


+ 1
- 1
ge/graph/passes/infer_base_pass.h View File

@@ -36,7 +36,7 @@ class InferBasePass : public BaseNodePass {
* @param dst, output TensorDesc to be updated
* @return
*/
virtual graphStatus UpdateTensorDesc(GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) = 0;
virtual graphStatus UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) = 0;

/**
* Update the output TensorDesc for nodes which contain subgraphs.


+ 26
- 6
ge/graph/passes/infer_value_range_pass.cc View File

@@ -207,7 +207,7 @@ bool InferValueRangePass::InputHasUnknownValueRange(const NodePtr &node) const {
return has_unknown_value_range;
}

graphStatus InferValueRangePass::UpdateTensorDesc(GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) {
graphStatus InferValueRangePass::UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) {
if (src == nullptr || dst == nullptr) {
REPORT_CALL_ERROR("E19999", "While updating tensor desc, input desc is null.");
GELOGE(GRAPH_FAILED, "[Param][check] While updating tensor desc, input desc is null.");
@@ -301,12 +301,26 @@ graphStatus InferValueRangePass::ConstructData(const GeTensorDesc &tensor_desc,
GeTensorPtr &output_ptr) {
std::vector<std::pair<int64_t, int64_t>> value_range;
(void)tensor_desc.GetValueRange(value_range);
if (static_cast<int64_t>(value_range.size()) != tensor_desc.GetShape().GetShapeSize()) {
GELOGW("Value range of input %s is invalid.", tensor_desc.GetName().c_str());
size_t value_range_data_num = value_range.size();
auto tensor_shape = tensor_desc.GetShape();
bool value_range_and_tensor_shape_matched = true;
if (tensor_shape.IsScalar()){
// scalar tensor has only one value_range pair
if (value_range_data_num != 1) {
value_range_and_tensor_shape_matched = false;
}
} else {
// normal tensor, value_range size is equal to tensor shape size.
if (static_cast<int64_t>(value_range_data_num) != tensor_shape.GetShapeSize()) {
value_range_and_tensor_shape_matched = false;
}
}
if (!value_range_and_tensor_shape_matched) {
GELOGW("Input %s value range and tensor shape do not match. Value range size is %zu, tensor shape is %s.",
tensor_desc.GetName().c_str(), value_range_data_num, formats::ShapeToString(tensor_shape).c_str());
return GRAPH_PARAM_INVALID;
}

size_t value_range_data_num = value_range.size();
unique_ptr<T[]> buf(new (std::nothrow) T[value_range_data_num]());
if (buf == nullptr) {
REPORT_INNER_ERROR("E19999", "New buf failed");
@@ -494,10 +508,16 @@ void InferValueRangePass::ConstructValueRange(const GeTensorPtr &left_tensor, co
GELOGI("Output tensor of cpu kernel does not have data, no way to set value range.");
return;
}
for (auto j = 0; j < left_tensor->GetTensorDesc().GetShape().GetShapeSize(); ++j) {
auto left_tensor_shape = left_tensor->GetTensorDesc().GetShape();
for (auto j = 0; j < left_tensor_shape.GetShapeSize(); ++j) {
auto left = static_cast<int64_t>(*(x + j));
auto right = static_cast<int64_t>(*(y + j));
value_range.emplace_back(std::make_pair(left, right));
value_range.emplace_back(left, right);
}

if (left_tensor_shape.IsScalar()) {
GELOGD("When inferring value range, output tensors of cpu kernel are scalar tensors.");
value_range.emplace_back(static_cast<int64_t>(*x), static_cast<int64_t>(*y));
}
}
} // namespace ge

+ 1
- 1
ge/graph/passes/infer_value_range_pass.h View File

@@ -26,7 +26,7 @@ class InferValueRangePass : public InferBasePass {

private:
std::string SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const override;
graphStatus UpdateTensorDesc(GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override;
graphStatus UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override;
graphStatus UpdateOutputFromSubgraphs(const std::vector<GeTensorDescPtr> &src, GeTensorDescPtr &dst) override;
graphStatus UpdateOutputFromSubgraphsForMultiDims(const std::vector<GeTensorDescPtr> &src,
GeTensorDescPtr &dst) override;


+ 37
- 15
ge/graph/passes/infershape_pass.cc View File

@@ -90,6 +90,9 @@ Status InferShapePass::SuspendV1LoopExitNodes(const NodePtr &node) {
GELOGI("Found v1 loop when infershape, suspend Exit node %s, type %s.", anchor_2_node.second->GetName().c_str(),
anchor_2_node.second->GetType().c_str());
auto iter = graphs_2_suspend_nodes_.find(GetCurrentGraphName());
if (iter == graphs_2_suspend_nodes_.end()) {
continue;
}
auto &suspend_nodes = graphs_2_suspend_nodes_[GetCurrentGraphName()];
if (suspend_nodes.nodes_set.insert(anchor_2_node.second).second) {
suspend_nodes.nodes.push(anchor_2_node.second);
@@ -102,7 +105,7 @@ Status InferShapePass::SuspendV1LoopExitNodes(const NodePtr &node) {
Status InferShapePass::Infer(NodePtr &node) {
auto ret = SuspendV1LoopExitNodes(node);
if (ret != SUCCESS) {
//todo LOG
GELOGE(ret, "Failed to suspend exit node in v1 control flow loop.");
return ret;
}
bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag();
@@ -117,7 +120,9 @@ Status InferShapePass::Infer(NodePtr &node) {
if (!is_unknown_graph) {
auto inference_context = ShapeRefiner::CreateInferenceContext(node);
GE_CHECK_NOTNULL(inference_context);
GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size());
vector<AscendString> marks;
inference_context->GetMarks(marks);
GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), marks.size());
op.SetInferenceContext(inference_context);
}

@@ -128,13 +133,16 @@ Status InferShapePass::Infer(NodePtr &node) {
GELOGE(GRAPH_FAILED, "[Call][InferFunction] failed, node:%s.", node->GetName().c_str());
return GRAPH_FAILED;
}
UpdateCurNodeOutputDesc(node);
if (!is_unknown_graph) {
auto ctx_after_infer = op.GetInferenceContext();
if (ctx_after_infer != nullptr) {
GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size());
if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !ctx_after_infer->GetMarks().empty()) {
vector<AscendString> marks;
ctx_after_infer->GetMarks(marks);
GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), marks.size());
if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !marks.empty()) {
GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(),
ctx_after_infer->GetMarks().size());
marks.size());
ShapeRefiner::PushToContextMap(node, ctx_after_infer);
}
}
@@ -180,15 +188,29 @@ bool InferShapePass::SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDe
return true;
}

graphStatus InferShapePass::UpdateTensorDesc(GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) {
// refresh src itself
src->SetOriginShape(src->GetShape());
src->SetOriginDataType(src->GetDataType());
TensorUtils::SetRealDimCnt(*src, static_cast<uint32_t>(src->GetOriginShape().GetDims().size()));
vector<pair<int64_t, int64_t>> src_shape_range;
src->GetShapeRange(src_shape_range);
src->SetOriginShapeRange(src_shape_range);
void InferShapePass::UpdateCurNodeOutputDesc(NodePtr &node) {
auto op_desc = node->GetOpDesc();
for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx());
GE_IF_BOOL_EXEC(output_tensor == nullptr, continue);
GE_IF_BOOL_EXEC(output_tensor->MutableShape().GetDims().empty(),
output_tensor->SetOriginShape(output_tensor->GetShape()));

ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast<uint32_t>(output_tensor->GetOriginShape().GetDims()
.size()));
output_tensor->SetOriginDataType(output_tensor->GetDataType());
// set output origin shape range
std::vector<std::pair<int64_t, int64_t>> range;
(void)output_tensor->GetShapeRange(range);
output_tensor->SetOriginShapeRange(range);
GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s",
node->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(),
TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(),
TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str());
}
}

graphStatus InferShapePass::UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) {
changed = false;
if (SameTensorDesc(src, dst)) {
GELOGD("Peer dst tensor_desc is same as src tensor_desc. No need update.");
@@ -213,7 +235,7 @@ graphStatus InferShapePass::CallInferShapeFunc(NodePtr &node, Operator &op) {
auto ret = op_desc->CallInferFunc(op);
if (ret == GRAPH_PARAM_INVALID) {
// Op ir no infer func, try to get infer func from operator factory
auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType());
auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType().c_str());
if (node_op.IsEmpty()) {
GELOGW("get op from OperatorFactory fail. opType: %s", op_type.c_str());
return ret;
@@ -318,7 +340,7 @@ graphStatus InferShapePass::UpdateOutputFromSubgraphsForMultiDims(const std::vec
Status InferShapePass::OnSuspendNodesLeaked() {
auto iter = graphs_2_suspend_nodes_.find(GetCurrentGraphName());
if (iter == graphs_2_suspend_nodes_.end()) {
// todo log warn
GELOGW("There is no suspend nodes on graph %s", GetCurrentGraphName().c_str());
return SUCCESS;
}
if (!iter->second.nodes.empty()) {


+ 2
- 1
ge/graph/passes/infershape_pass.h View File

@@ -26,7 +26,7 @@ class InferShapePass : public InferBasePass {
std::string SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const override;
graphStatus Infer(NodePtr &node) override;

graphStatus UpdateTensorDesc(GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override;
graphStatus UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override;
graphStatus UpdateOutputFromSubgraphs(const std::vector<GeTensorDescPtr> &src, GeTensorDescPtr &dst) override;
graphStatus UpdateOutputFromSubgraphsForMultiDims(const std::vector<GeTensorDescPtr> &src,
GeTensorDescPtr &dst) override;
@@ -36,6 +36,7 @@ class InferShapePass : public InferBasePass {
private:
graphStatus CallInferShapeFunc(NodePtr &node, Operator &op);
bool SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDescPtr &dst);
void UpdateCurNodeOutputDesc(NodePtr &node);
Status SuspendV1LoopExitNodes(const NodePtr &node);
struct SuspendNodes {
std::stack<NodePtr> nodes;


+ 2
- 3
ge/graph/passes/merge_pass.cc View File

@@ -31,7 +31,6 @@ namespace ge {
const int kValueIndexOutputIndex = 1;
const size_t kCaseNoInput = 0;
const size_t kCaseOneInput = 1;
const bool kWillRepassImmediately = true;

Status MergePass::Run(NodePtr &node) {
GELOGD("MergePass running");
@@ -83,14 +82,14 @@ Status MergePass::Run(NodePtr &node) {
}
auto in_node = in_data_nodes.at(0);
if (IsMergeInputNeedOptimized(in_node)) {
if (IsolateAndDeleteNode(in_node, {0}, kWillRepassImmediately) != SUCCESS) {
if (IsolateAndDeleteNode(in_node, {0}) != SUCCESS) {
REPORT_CALL_ERROR("E19999", "Isolate and delete node:%s(%s) failed",
in_node->GetName().c_str(), in_node->GetType().c_str());
GELOGE(FAILED, "[Remove][Node] %s failed.", in_node->GetName().c_str());
return FAILED;
}
}
return IsolateAndDeleteNode(node, merge_io_map, kWillRepassImmediately);
return IsolateAndDeleteNode(node, merge_io_map);
}
default: {
// Case C: input_count > 1, the merge node can not be optimized


+ 1
- 2
ge/graph/passes/switch_dead_branch_elimination.cc View File

@@ -28,7 +28,6 @@ namespace {
const std::vector<int>::size_type kDataInputIndex = 0;
const std::vector<int>::size_type kPredInputIndex = 1;
const int kDefaultInputIndex = -1;
const bool kWillRepassImmediately = true;

bool ParsePred(const ConstGeTensorPtr &tensor) {
if (tensor == nullptr) {
@@ -135,7 +134,7 @@ Status SwitchDeadBranchElimination::DeleteSwitchNode(NodePtr &node, NodePtr &pre
return FAILED;
}
switch_io_map[out_index] = kDataInputIndex;
return IsolateAndDeleteNode(node, switch_io_map, kWillRepassImmediately);
return IsolateAndDeleteNode(node, switch_io_map);
}

Status SwitchDeadBranchElimination::Run(NodePtr &node) {


+ 1
- 0
ge/graph/preprocess/insert_op/util_insert_aipp_op.cc View File

@@ -568,6 +568,7 @@ Status InsertNewOpUtil::GetDataRelatedNode(NodePtr &node, std::map<NodePtr, std:
}

std::unique_ptr<domi::AippOpParams> aipp_params(new (std::nothrow) domi::AippOpParams());
GE_CHECK_NOTNULL(aipp_params);
ge::GeAttrValue::NAMED_ATTRS aipp_attr;
GE_CHK_BOOL_RET_STATUS(AttrUtils::GetNamedAttrs(data_op, ATTR_NAME_AIPP, aipp_attr), ACL_ERROR_GE_AIPP_NOT_EXIST,
"[Get][Attr] %s from op:%s failed", ATTR_NAME_AIPP.c_str(), data_op->GetName().c_str());


+ 1
- 1
ge/graph/preprocess/multi_batch_copy_graph.cc View File

@@ -1206,7 +1206,7 @@ Status MultiBatchGraphCopyer::CheckCopyResult(const std::vector<NodePtr> &start_
auto dims = NodeUtils::GetOutputDesc(*node, kDataOutIndex).GetShape().GetDims();
if (!IsAllDimsPositive(dims)) {
REPORT_CALL_ERROR("E19999", "Failed to copy multi batch graph, the node %s still has unknown shape %s",
node->GetName().c_str(), formats::ShapeToString(dims).c_str());
node->GetName().c_str(), formats::ShapeToString(dims).c_str());
GELOGE(INTERNAL_ERROR, "[Check][Param] Failed to copy multi batch graph, the node %s still has unknown shape %s",
node->GetName().c_str(), formats::ShapeToString(dims).c_str());
return INTERNAL_ERROR;


+ 5
- 3
ge/hybrid/executor/hybrid_model_async_executor.cc View File

@@ -295,13 +295,15 @@ Status HybridModelAsyncExecutor::PrepareInputs(const InputData &current_data, Hy
}
}
tensor_desc->SetShape(shape);
args.input_desc[input_index] = tensor_desc;
GELOGD("Update shape of input[%zu] to [%s]", input_index, tensor_desc->MutableShape().ToString().c_str());
GELOGD("Update shape[%s] of input[%zu] to [%s]",
shape.ToString().c_str(), input_index, tensor_desc->MutableShape().ToString().c_str());
GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorMemorySizeInBytes(*tensor_desc, tensor_size),
"[Invoke][GetTensorMemorySizeInBytes]Failed to calc tensor size,"
"index = %zu, shape = [%s], model_id = %u.",
input_index, tensor_desc->GetShape().ToString().c_str(), model_id_);
GELOGD("Input tensor[%zu] size = %zu", input_index, tensor_size);
GELOGD("Input tensor[%zu] size = %ld", input_index, tensor_size);
TensorUtils::SetSize(*tensor_desc, tensor_size);
args.input_desc[input_index] = tensor_desc;
}

GE_CHECK_GE(tensor_size, 0);


+ 0
- 4
ge/hybrid/executor/hybrid_model_executor.cc View File

@@ -33,9 +33,6 @@ HybridModelExecutor::HybridModelExecutor(HybridModel *model, uint32_t device_id,
}

HybridModelExecutor::~HybridModelExecutor() {
if (context_.rt_gen_context != nullptr) {
(void) rtCtxDestroy(context_.rt_gen_context);
}
}

Status HybridModelExecutor::Init() {
@@ -139,7 +136,6 @@ Status HybridModelExecutor::Cleanup() {

Status HybridModelExecutor::InitExecutionContext() {
GE_CHK_RT_RET(rtCtxGetCurrent(&context_.rt_context));
GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0));
GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context));

context_.global_step = model_->GetGlobalStep();


+ 0
- 1
ge/hybrid/executor/hybrid_model_pipeline_executor.cc View File

@@ -191,7 +191,6 @@ HybridModelPipelineExecutor::HybridModelPipelineExecutor(HybridModel *model, uin
}

Status StageExecutor::InitExecutionContext() {
GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0));
GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context));

context_.model = model_;


+ 9
- 2
ge/hybrid/executor/worker/task_compile_engine.cc View File

@@ -21,10 +21,17 @@
namespace ge {
namespace hybrid {
Status TaskCompileEngine::Compile(NodeState &node_state, GraphExecutionContext *context) {
const auto &node_item = *node_state.GetNodeItem();
GE_CHECK_NOTNULL(context);
rtContext_t rt_gen_context = nullptr;
GE_CHK_RT_RET(rtCtxCreate(&rt_gen_context, RT_CTX_GEN_MODE, 0));
std::function<void()> callback = [&]() {
(void) rtCtxDestroy(rt_gen_context);
GE_CHK_RT(rtCtxSetCurrent(context->rt_context));
};
GE_MAKE_GUARD(rt_gen_context, callback);

const auto &node_item = *node_state.GetNodeItem();
RECORD_COMPILE_EVENT(context, node_item.NodeName().c_str(), "[Compile] Start");
GE_CHK_RT_RET(rtCtxSetCurrent(context->rt_gen_context));

if (context->ge_context != nullptr) {
GetThreadLocalContext() = *context->ge_context;


+ 1
- 0
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -1044,6 +1044,7 @@ Status HybridModelBuilder::InitConstantOps() {
} else {
var_tensor.reset(new(std::nothrow)TensorValue(nullptr, 0));
}
GE_CHECK_NOTNULL(var_tensor);
} else {
GE_CHK_STATUS_RET_NOLOG(VarNodeToTensor(var_node, var_tensor));
GELOGD("Init const op tensor. name = %s, size = %ld", var_name.c_str(), var_tensor->GetSize());


+ 23
- 2
ge/hybrid/model/node_item.cc View File

@@ -24,6 +24,8 @@
namespace ge {
namespace hybrid {
namespace {
const uint8_t kMaxTransCount = 3;
const uint32_t kTransOpIoSize = 1;
const char *const kAttrNameOriginalFusionGraph = "_original_fusion_graph";
const char *const kNodeTypeRetVal = "_RetVal";
const std::set<std::string> kControlOpTypes{
@@ -39,6 +41,25 @@ const std::set<std::string> kMergeOpTypes{
MERGE, REFMERGE, STREAMMERGE
};

bool IsEnterFeedNode(NodePtr node) {
// For: Enter -> node
// For: Enter -> Cast -> node
// For: Enter -> TransData -> Cast -> node
for (uint8_t i = 0; i < kMaxTransCount; ++i) {
if (kEnterOpTypes.count(NodeUtils::GetNodeType(node)) > 0) {
GELOGD("Node[%u] is Enter feed node.", node->GetName().c_str());
return true;
}

const auto all_nodes = node->GetInDataNodes();
if (all_nodes.size() != kTransOpIoSize || node->GetAllInDataAnchorsSize() != kTransOpIoSize) {
return false;
}
node = all_nodes.at(0);
}
return false;
}

Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) {
uint32_t parent_index = 0;
if (!AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
@@ -399,7 +420,7 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) {
data_anchors.emplace(anchor_index);
}
// If Enter feed Not Merge, take as root Node.
if (IsEnterOp() && (node_item->node_type != STREAMMERGE)) {
if (IsEnterFeedNode(node) && (node_item->node_type != STREAMMERGE)) {
auto &data_anchors = node_item->enter_data_[this];
data_anchors.emplace(anchor_index);
}
@@ -419,7 +440,7 @@ void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) {
node_item->root_ctrl_.emplace(this);
}
// If Enter feed control signal, take as root Node.
if (IsEnterOp() && (node_item->node_type != STREAMMERGE && node_item->node_type != STREAMACTIVE)) {
if (IsEnterFeedNode(node) && (node_item->node_type != STREAMMERGE && node_item->node_type != STREAMACTIVE)) {
node_item->enter_ctrl_.emplace(this);
}
GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str());


+ 7
- 2
ge/ir_build/option_utils.cc View File

@@ -50,6 +50,8 @@ const std::set<std::string> kBufferOptimizeSupportOption = {"l1_optimize", "l2_o
const char *const kBufferOptimizeSupport = "only support l2_optimize, off_optimize";
const char *const IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT = "high_performance";
const char *const IR_OPTION_OP_SELECT_IMPLMODE_PRECISON = "high_precision";
const char *const IR_OPTION_OP_SELECT_IMPLMODE_HIGH_PRECISION_FOR_ALL = "high_precision_for_all";
const char *const IR_OPTION_OP_SELECT_IMPLMODE_HIGH_PERFORMANCE_FOR_ALL = "high_performance_for_all";
const char *const kInputShapeSample1 = "\"input_name1:n1,c1,h1,w1\"";
const char *const kInputShapeSample2 = "\"input_name1:1,3,224,224\"";
const char *const kSplitError1 = "size not equal to 2 split by \":\"";
@@ -57,7 +59,8 @@ const char *const kEmptyError = "can not be empty";
const char *const kFloatNumError = "exist float number";
const char *const kDigitError = "is not digit";
const char *const kCompressWeightError = "it must be appointed when appoint parameter[--optypelist_for_implmode]";
const char *const kSelectImplmodeError = "only support high_performance, high_precision";
const char *const kSelectImplmodeError = "only support high_performance, high_precision, "
"high_precision_for_all, high_performance_for_all";
const char *const kDynamicBatchSizeError = "It can only contains digit, \",\", \" \"";
const char *const kDynamicImageSizeError = "It can only contains digit, \",\", \" \" and \";\"";
const char *const kKeepDtypeError = "file not found";
@@ -782,7 +785,9 @@ Status CheckImplmodeParamValid(const std::string &optypelist_for_implmode, std::
op_select_implmode = IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT;
} else {
if (op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT &&
op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_PRECISON) {
op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_PRECISON &&
op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_HIGH_PRECISION_FOR_ALL &&
op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_HIGH_PERFORMANCE_FOR_ALL) {
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"--op_select_implmode", op_select_implmode.c_str(),
kSelectImplmodeError});


+ 4
- 3
ge/offline/main.cc View File

@@ -143,7 +143,8 @@ DEFINE_string(output_type, "",

DEFINE_string(op_select_implmode, "",
"Optional; op select implmode! "
"Support high_precision, high_performance.");
"Support high_precision, high_performance, "
"high_precision_for_all, high_performance_for_all.");

DEFINE_string(optypelist_for_implmode, "",
"Optional; Nodes need use implmode selected in op_select_implmode "
@@ -311,8 +312,8 @@ class GFlagUtils {
"scenarios by using a configuration file.\n"
" --auto_tune_mode Set tune mode. E.g.: \"GA,RL\", support configure multiple, spit by ,\n"
" --op_bank_path Set the path of the custom repository generated after operator tuning with Auto Tune.\n"
" --op_select_implmode Set op select implmode. Support high_precision, high_performance. "
"default: high_performance\n"
" --op_select_implmode Set op select implmode. Support high_precision, high_performance, "
"high_precision_for_all, high_performance_for_all. default: high_performance\n"
" --optypelist_for_implmode Appoint which op to select implmode, cooperated with op_select_implmode.\n"
" Separate multiple nodes with commas (,). Use double quotation marks (\") "
"to enclose each argument. E.g.: \"node_name1,node_name2\"\n"


+ 68
- 85
ge/single_op/single_op_model.cc View File

@@ -95,35 +95,6 @@ Status CheckInferDepend(GeModelPtr &ge_model, bool &is_infer_depend, bool &is_ho
}
return SUCCESS;
}

Status NeedHybridModel(GeModelPtr &ge_model, bool &flag) {
bool is_infer_depend = false;
bool is_host_mem = false;
GE_CHK_STATUS_RET(CheckInferDepend(ge_model, is_infer_depend, is_host_mem), "[Check][InferDepend] failed.");
bool need_d2h_cpy = is_infer_depend && !is_host_mem;
auto tasks = ge_model->GetModelTaskDefPtr()->task();
int32_t kernel_task_num = 0;
for (int i = 0; i < tasks.size(); ++i) {
auto task_type = static_cast<rtModelTaskType_t>(tasks[i].type());
if (task_type == RT_MODEL_TASK_KERNEL || task_type == RT_MODEL_TASK_ALL_KERNEL) {
const auto &context = task_type == RT_MODEL_TASK_KERNEL ? tasks[i].kernel().context() :
tasks[i].kernel_with_handle().context();
auto kernel_type = static_cast<ccKernelType>(context.kernel_type());
if (kernel_type == ccKernelType::TE) {
if (need_d2h_cpy) {
flag = true;
return SUCCESS;
}
kernel_task_num++;
if (kernel_task_num > 1) {
flag = true;
return SUCCESS;
}
}
}
}
return SUCCESS;
}
} // namespace

SingleOpModel::SingleOpModel(const std::string &model_name, const void *model_data, uint32_t model_size)
@@ -558,14 +529,15 @@ Status SingleOpModel::BuildOp(StreamResource &resource, SingleOp &single_op) {
return BuildTaskList(&resource, single_op);
}

Status SingleOpModel::BuildModelTaskKernel(StreamResource *stream_resource, const TaskDef &task_def,
DynamicSingleOp &single_op) {
auto task_type = static_cast<rtModelTaskType_t>(task_def.type());
const auto &context = task_type == RT_MODEL_TASK_KERNEL ? task_def.kernel().context() :
task_def.kernel_with_handle().context();
Status SingleOpModel::BuildTaskListForDynamicOp(StreamResource *stream_resource, DynamicSingleOp &single_op) {
auto ge_model = model_helper_.GetGeModel();
GE_CHECK_NOTNULL(ge_model);

auto kernel_type = static_cast<ccKernelType>(context.kernel_type());
if (kernel_type == ccKernelType::TE) {
auto compute_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph());
GE_CHECK_NOTNULL(compute_graph);
single_op.compute_graph_ = compute_graph;
if (tbe_tasks_.size() > 0) {
const auto &task_def = tbe_tasks_[0];
GELOGD("Building TBE task.");
TbeOpTask *tbe_task = nullptr;
GE_CHK_STATUS_RET_NOLOG(BuildKernelTask(task_def, &tbe_task));
@@ -575,71 +547,81 @@ Status SingleOpModel::BuildModelTaskKernel(StreamResource *stream_resource, cons
tbe_task->stream_resource_ = stream_resource;
}
single_op.op_task_.reset(tbe_task);
} else if (kernel_type == ccKernelType::AI_CPU || kernel_type == ccKernelType::CUST_AI_CPU) {
GELOGD("Building AICPU_CC task");
OpTask *task = nullptr;
uint64_t dynamic_singleop_kernel_id = aicpu_kernel_id++;
GELOGI("Build dynamic singleOp CCTask, kernel_id = %lu", dynamic_singleop_kernel_id);
GE_CHK_STATUS_RET_NOLOG(BuildCpuKernelTask(task_def.kernel(), &task, dynamic_singleop_kernel_id));
task->SetModelArgs(model_name_, model_id_);
single_op.op_task_.reset(task);
} else {
GELOGE(ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID,
"[Check][Param:TaskDef]Only TBE, AI_CPU, CUST_AI_CPU kernel are supported, but got %u",
context.kernel_type());
REPORT_INNER_ERROR("E19999",
"BuildModelTaskKernel fail for got:%u not supported, Only TBE, AI_CPU, CUST_AI_CPU kernel are supported.",
context.kernel_type());
return ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID;
}
return SUCCESS;
}

Status SingleOpModel::BuildTaskListForDynamicOp(StreamResource *stream_resource, DynamicSingleOp &single_op) {
auto ge_model = model_helper_.GetGeModel();
GE_CHECK_NOTNULL(ge_model);

auto compute_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph());
GE_CHECK_NOTNULL(compute_graph);
single_op.compute_graph_ = compute_graph;
auto tasks = ge_model->GetModelTaskDefPtr()->task();
for (int i = 0; i < tasks.size(); ++i) {
const TaskDef &task_def = tasks[i];
GELOGI("[%s] Task[%d], type = [%u], DebugString = [%s]", model_name_.c_str(), i, task_def.type(),
task_def.DebugString().c_str());
} else if (aicpu_tasks_.size() > 0) {
const auto &task_def = aicpu_tasks_[0];
auto task_type = static_cast<rtModelTaskType_t>(task_def.type());
if (task_type == RT_MODEL_TASK_KERNEL || task_type == RT_MODEL_TASK_ALL_KERNEL) {
if (single_op.op_task_ != nullptr) {
GELOGE(ACL_ERROR_GE_OP_TASK_TYPE_INVALID, "[Check][TaskType]Do not support dynamic op with multiple tasks.");
REPORT_INNER_ERROR("E19999",
"BuildTaskListForDynamicOp fail for Do not support dynamic op with multiple tasks.");
return ACL_ERROR_GE_OP_TASK_TYPE_INVALID;
}
GE_CHK_STATUS_RET_NOLOG(BuildModelTaskKernel(stream_resource, task_def, single_op));
if (task_type == RT_MODEL_TASK_KERNEL) {
GELOGD("Building AICPU_CC task");
OpTask *task = nullptr;
uint64_t dynamic_singleop_kernel_id = aicpu_kernel_id++;
GELOGI("Build dynamic singleOp CCTask, kernel_id = %lu", dynamic_singleop_kernel_id);
GE_CHK_STATUS_RET_NOLOG(BuildCpuKernelTask(task_def.kernel(), &task, dynamic_singleop_kernel_id));
task->SetModelArgs(model_name_, model_id_);
single_op.op_task_.reset(task);
} else if (task_type == RT_MODEL_TASK_KERNEL_EX) {
if (single_op.op_task_ != nullptr) {
GELOGE(ACL_ERROR_GE_OP_TASK_TYPE_INVALID, "[Check][TaskType]Do not support dynamic op with multiple tasks.");
REPORT_INNER_ERROR("E19999",
"BuildTaskListForDynamicOp fail for Do not support dynamic op with multiple tasks.");
return ACL_ERROR_GE_OP_TASK_TYPE_INVALID;
}
GELOGD("Building AICPU_TF task");
AiCpuTask *aicpu_task = nullptr;
uint64_t dynamic_singleop_kernel_id = aicpu_kernel_id++;
GELOGI("Build dynamic singleOp TfTask, kernel_id = %lu", dynamic_singleop_kernel_id);
GE_CHK_STATUS_RET_NOLOG(BuildKernelExTask(task_def.kernel_ex(), &aicpu_task, dynamic_singleop_kernel_id));
if (aicpu_task->GetUnknownType() == DEPEND_COMPUTE) {
if (i >= tasks.size() - 1) {
if (aicpu_tasks_.size() < 2) {
GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[Check][Task]The copy task of the fourth operator was not found.");
REPORT_INNER_ERROR("E19999", "The copy task of the fourth operator was not found.");
return ACL_ERROR_GE_PARAM_INVALID;
}
++i;
const TaskDef &copy_task_def = tasks[i];
const TaskDef &copy_task_def = aicpu_tasks_[1];
GE_CHK_STATUS_RET_NOLOG(aicpu_task->SetMemCopyTask(copy_task_def.kernel_ex()));
}
aicpu_task->SetModelArgs(model_name_, model_id_);
single_op.op_task_.reset(aicpu_task);
}
}
return SUCCESS;
}

Status SingleOpModel::NeedHybridModel(GeModelPtr &ge_model, bool &need_hybrid_model) {
bool is_infer_depend = false;
bool is_host_mem = false;
GE_CHK_STATUS_RET(CheckInferDepend(ge_model, is_infer_depend, is_host_mem), "[Check][InferDepend] failed.");
bool need_d2h_cpy = is_infer_depend && !is_host_mem;
bool aicpu_multi_task = tbe_tasks_.size() >= 1 && aicpu_tasks_.size() >= 1;
bool aicore_multi_task = tbe_tasks_.size() > 1;
need_hybrid_model = need_d2h_cpy || aicore_multi_task || aicpu_multi_task;
return SUCCESS;
}

Status SingleOpModel::ParseTasks() {
auto ge_model = model_helper_.GetGeModel();
GE_CHECK_NOTNULL(ge_model);

auto tasks = ge_model->GetModelTaskDefPtr()->task();
for (int i = 0; i < tasks.size(); ++i) {
TaskDef &task_def = tasks[i];
GELOGI("[%s] Task[%d], type = [%u], DebugString = [%s]", model_name_.c_str(), i, task_def.type(),
task_def.DebugString().c_str());
auto task_type = static_cast<rtModelTaskType_t>(task_def.type());
if (task_type == RT_MODEL_TASK_KERNEL) {
const auto &kernel_def = task_def.kernel();
const auto &context = kernel_def.context();
auto kernel_type = static_cast<ccKernelType>(context.kernel_type());
if (kernel_type == ccKernelType::TE) {
tbe_tasks_.emplace_back(task_def);
} else if (kernel_type == ccKernelType::AI_CPU || kernel_type == ccKernelType::CUST_AI_CPU) {
aicpu_tasks_.emplace_back(task_def);
} else {
GELOGE(ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID,
"[Check][Param:TaskDef]Only TBE, AI_CPU, CUST_AI_CPU kernel are supported, but got %u",
context.kernel_type());
REPORT_INNER_ERROR("E19999",
"BuildModelTaskKernel fail for got:%u not supported, Only TBE, AI_CPU, CUST_AI_CPU kernel are supported.",
context.kernel_type());
return ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID;
}
} else if (task_type == RT_MODEL_TASK_ALL_KERNEL) {
tbe_tasks_.emplace_back(task_def);
} else if (task_type == RT_MODEL_TASK_KERNEL_EX) {
aicpu_tasks_.emplace_back(task_def);
} else {
// skip
GELOGD("Skip task type: %d", static_cast<int>(task_type));
@@ -654,6 +636,7 @@ Status SingleOpModel::BuildDynamicOp(StreamResource &resource, DynamicSingleOp &
GE_CHK_STATUS_RET_NOLOG(InitModelMem(resource));
model_params_.memory_size = UINT64_MAX;
model_params_.graph_is_dynamic = true;
GE_CHK_STATUS_RET(ParseTasks(), "[Parse][Tasks] failed.");

auto ge_model = model_helper_.GetGeModel();
GE_CHECK_NOTNULL(ge_model);


+ 5
- 2
ge/single_op/single_op_model.h View File

@@ -71,13 +71,16 @@ class SingleOpModel {
Status BuildKernelTask(const domi::TaskDef &task_def, TbeOpTask **task);
Status BuildKernelExTask(const domi::KernelExDef &kernel_def, AiCpuTask **task, uint64_t kernel_id);
Status BuildCpuKernelTask(const domi::KernelDef &kernel_def, OpTask **task, uint64_t kernel_id);
Status BuildModelTaskKernel(StreamResource *stream_resource, const domi::TaskDef &task_def,
DynamicSingleOp &single_op);

static void ParseOpModelParams(ModelHelper &model_helper, SingleOpModelParam &param);
void ParseArgTable(OpTask *task, SingleOp &op);
Status InitHybridModelExecutor(const StreamResource &resource, const GeModelPtr &ge_model, SingleOp &single_op);
Status SetHostMemTensor(DynamicSingleOp &single_op);
Status NeedHybridModel(GeModelPtr &ge_model, bool &flag);
Status ParseTasks();

std::vector<domi::TaskDef> tbe_tasks_;
std::vector<domi::TaskDef> aicpu_tasks_;

std::string model_name_;
uint32_t model_id_ = 0;


+ 5
- 1
ge/single_op/task/op_task.h View File

@@ -33,6 +33,10 @@
#include "register/op_tiling.h"

namespace ge {
namespace {
const int kAddressNum = 2;
} // namespace

class StreamResource;
struct SingleOpModelParam;
class OpTask {
@@ -264,7 +268,7 @@ class MemcpyAsyncTask : public OpTask {
friend class SingleOpModel;
friend class RtsKernelTaskBuilder;

uintptr_t addresses_[2];
uintptr_t addresses_[kAddressNum];
size_t dst_max_;
size_t count_;
rtMemcpyKind_t kind_;


+ 24
- 1
ge/single_op/task/tbe_task_builder.cc View File

@@ -104,7 +104,7 @@ Status TbeTaskBuilder::DoRegisterBinary(const OpKernelBin &kernel_bin, void **bi
binary.version = 0;
binary.data = kernel_bin.GetBinData();
binary.length = kernel_bin.GetBinDataSize();
binary.magic = param.core_type == 0 ? RT_DEV_BINARY_MAGIC_ELF : RT_DEV_BINARY_MAGIC_ELF_AIVEC;
GE_CHK_STATUS_RET_NOLOG(GetMagic(binary.magic));
Status ret = 0;
if (task_def_.type() == RT_MODEL_TASK_ALL_KERNEL) {
ret = rtRegisterAllKernel(&binary, bin_handle);
@@ -416,4 +416,27 @@ Status TbeTaskBuilder::InitTilingInfo(TbeOpTask &task) {
task.EnableDynamicSupport(node_, tiling_buffer, static_cast<uint32_t>(max_size));
return SUCCESS;
}

Status TbeTaskBuilder::GetMagic(uint32_t &magic) const {
std::string json_string;
GE_IF_BOOL_EXEC(AttrUtils::GetStr(op_desc_, TVM_ATTR_NAME_MAGIC, json_string),
GELOGD("Get original type of session_graph_id."));
if (json_string == "RT_DEV_BINARY_MAGIC_ELF") {
magic = RT_DEV_BINARY_MAGIC_ELF;
} else if (json_string == "RT_DEV_BINARY_MAGIC_ELF_AIVEC") {
magic = RT_DEV_BINARY_MAGIC_ELF_AIVEC;
} else if (json_string == "RT_DEV_BINARY_MAGIC_ELF_AICUBE") {
magic = RT_DEV_BINARY_MAGIC_ELF_AICUBE;
} else {
REPORT_INNER_ERROR("E19999", "Attr:%s in op:%s(%s), value:%s check invalid",
TVM_ATTR_NAME_MAGIC.c_str(), op_desc_->GetName().c_str(),
op_desc_->GetType().c_str(), json_string.c_str());
GELOGE(PARAM_INVALID, "[Check][Param] Attr:%s in op:%s(%s), value:%s check invalid",
TVM_ATTR_NAME_MAGIC.c_str(), op_desc_->GetName().c_str(),
op_desc_->GetType().c_str(), json_string.c_str());
return PARAM_INVALID;
}
return SUCCESS;
}

} // namespace ge

+ 1
- 0
ge/single_op/task/tbe_task_builder.h View File

@@ -105,6 +105,7 @@ class TbeTaskBuilder {
const SingleOpModelParam &param);
Status DoRegisterBinary(const OpKernelBin &kernel_bin, void **bin_handle, const SingleOpModelParam &param) const;
Status DoRegisterMeta(void *bin_handle);
Status GetMagic(uint32_t &magic) const;

static Status DoRegisterFunction(void *bin_handle, const char *stub_name, const char *kernel_name);



+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit 9e4a51a9602195b82e326b853f5adbfefc3972b6
Subproject commit 9c9907b76a457f456072af96b8cbcfb7943beccc

+ 1
- 1
parser

@@ -1 +1 @@
Subproject commit 79536a196f89cf7a1f5852ff7304b9a7d7b12eff
Subproject commit 15a27afefe45f2abdb78787d629163aab9437599

+ 15
- 0
scripts/env/Dockerfile View File

@@ -38,5 +38,20 @@ RUN wget https://github.com/ccup/lcov/archive/refs/tags/add_lcov.tar.gz -O add_l

ENV PROJECT_HOME=/code/Turing/graphEngine

RUN mkdir /var/run/sshd
RUN echo "root:root" | chpasswd
RUN sed -i 's/\#PermitRootLogin prohibit-password/PermitRootLogin yes/' /etc/ssh/sshd_config
RUN sed 's@session\s*required\s*pam_loginuid.so@session optional pam_loginuid.so@g' -i /etc/pam.d/sshd

ENV NOTVISIBLE "in users profile"
RUN echo "export VISIBLE=now" >> /etc/profile

EXPOSE 22 7777

RUN useradd -ms /bin/bash debugger
RUN echo "debugger:ge123" | chpasswd

CMD ["/usr/sbin/sshd" "-D" "&"]

RUN echo "alias ge=/code/Turing/graphEngine/scripts/ge.sh">>~/.bashrc


+ 2
- 2
scripts/env/ge_env.sh View File

@@ -21,7 +21,7 @@ MOUNT_PROJECT_HOME=$(cd $PROJECT_HOME || return; pwd)

DOCKER_BUILD_ENV_NAME=${MOUNT_PROJECT_HOME#*/}
DOCKER_BUILD_ENV_NAME=${DOCKER_BUILD_ENV_NAME//\//\_}
DOCKER_IMAGE_TAG=ge_build_env.1.0.6
DOCKER_IMAGE_TAG=ge_build_env.1.0.9
DOCKER_IAMGE_NAME=joycode2art/turing
DOCKER_FULL_IMAGE_NAME=${DOCKER_IAMGE_NAME}:${DOCKER_IMAGE_TAG}

@@ -61,7 +61,7 @@ function enter_docker_env(){
if test -z "$(docker images |grep ${DOCKER_IAMGE_NAME} | grep ${DOCKER_IMAGE_TAG})"; then
echo "please run 'ge env --pull' to download images first!"
elif test -z "$(docker ps -a |grep ${DOCKER_BUILD_ENV_NAME})"; then
$docker_cmd run -it -v ${MOUNT_PROJECT_HOME}:/code/Turing/graphEngine --workdir ${docker_work_dir} --name ${DOCKER_BUILD_ENV_NAME} ${DOCKER_FULL_IMAGE_NAME} ${docker_bash_dir}
$docker_cmd run -p 7002:22 -p 7003:7777 --privileged=true -it -v ${MOUNT_PROJECT_HOME}:/code/Turing/graphEngine --workdir ${docker_work_dir} --name ${DOCKER_BUILD_ENV_NAME} ${DOCKER_FULL_IMAGE_NAME} ${docker_bash_dir}
elif test -z "$(docker ps |grep ${DOCKER_BUILD_ENV_NAME})"; then
$docker_cmd start ${DOCKER_BUILD_ENV_NAME}
$docker_cmd exec -w ${docker_work_dir} -it ${DOCKER_BUILD_ENV_NAME} ${docker_bash_dir}


+ 1
- 0
tests/depends/cce/CMakeLists.txt View File

@@ -60,6 +60,7 @@ set(SRCS
"${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc"
"${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc"
"${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc"
"${GE_CODE_DIR}/metadef/graph/utils/dumper/ge_graph_dumper.cc"
"${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc"
"${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc"
"${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc"


+ 0
- 13
tests/framework/CMakeLists.txt View File

@@ -17,16 +17,3 @@ include(cmake/graphengine.cmake)
add_subdirectory(easy_graph)
add_subdirectory(ge_graph_dsl)
add_subdirectory(ge_running_env)

file(GLOB_RECURSE UTILS_SRC CONFIGURE_DEPENDS
"utils/*.cc"
)

add_library(framework STATIC ${UTILS_SRC})

target_include_directories(framework
PUBLIC utils/
)

set_target_properties(framework PROPERTIES CXX_STANDARD 11)
target_link_libraries(framework PUBLIC ge_graph_dsl ge_with_env)

+ 19
- 3
tests/framework/easy_graph/include/easy_graph/builder/graph_dsl.h View File

@@ -26,16 +26,32 @@ EG_NS_BEGIN

////////////////////////////////////////////////////////////////
namespace detail {
template<typename GRAPH_BUILDER>
template <typename GRAPH_BUILDER>
Graph BuildGraph(const char *name, GRAPH_BUILDER builderInDSL) {
GraphBuilder builder(name);
builderInDSL(builder);
return std::move(*builder);
}

struct GraphDefiner {
GraphDefiner(const char *defaultName, const char *specifiedName = nullptr) {
name = specifiedName ? specifiedName : defaultName;
}

template <typename USER_BUILDER>
auto operator|(USER_BUILDER &&userBuilder) {
GraphBuilder graphBuilder{name};
std::forward<USER_BUILDER>(userBuilder)(graphBuilder);
return *graphBuilder;
}

private:
const char *name;
};

} // namespace detail

#define HAS_NAME(...) NOT_EMPTY_SELECT(__VA_ARGS__)
#define DEF_GRAPH(G, ...) ::EG_NS::Graph G = ::EG_NS::detail::BuildGraph(HAS_NAME(__VA_ARGS__)(__VA_ARGS__, #G), [&](::EG_NS::GraphBuilder& BUILDER)
#define DEF_GRAPH(G, ...) ::EG_NS::Graph G = ::EG_NS::detail::GraphDefiner(#G, ##__VA_ARGS__) | [&](auto &&BUILDER)
#define DATA_CHAIN(...) ::EG_NS::ChainBuilder(BUILDER, ::EG_NS::EdgeType::DATA)->__VA_ARGS__
#define CTRL_CHAIN(...) ::EG_NS::ChainBuilder(BUILDER, ::EG_NS::EdgeType::CTRL)->__VA_ARGS__
#define CHAIN(...) DATA_CHAIN(__VA_ARGS__)


+ 6
- 2
tests/framework/easy_graph/src/layout/graph_layout.cc View File

@@ -16,10 +16,15 @@

#include "easy_graph/layout/graph_layout.h"
#include "easy_graph/layout/layout_executor.h"
#include "easy_graph/layout/engines/graph_easy/graph_easy_executor.h"
#include "easy_graph/graph/graph.h"

EG_NS_BEGIN

namespace {
GraphEasyExecutor default_executor;
}

void GraphLayout::Config(LayoutExecutor &executor, const LayoutOption *opts) {
this->executor_ = &executor;
options_ = opts;
@@ -27,8 +32,7 @@ void GraphLayout::Config(LayoutExecutor &executor, const LayoutOption *opts) {

Status GraphLayout::Layout(const Graph &graph, const LayoutOption *opts) {
const LayoutOption *options = opts ? opts : this->options_;
if (!executor_)
return EG_UNIMPLEMENTED;
if (!executor_) return static_cast<LayoutExecutor &>(default_executor).Layout(graph, options);
return executor_->Layout(graph, options);
}



+ 37
- 0
tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/assert_error.h View File

@@ -0,0 +1,37 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef D52AA06185E34BBFB714FFBCDAB0D53A
#define D52AA06185E34BBFB714FFBCDAB0D53A

#include "ge_graph_dsl/ge.h"
#include <exception>
#include <string>

GE_NS_BEGIN

struct AssertError : std::exception {
AssertError(const char *file, int line, const std::string &info);

private:
const char *what() const noexcept override;

private:
std::string info;
};

GE_NS_END

#endif

+ 32
- 0
tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/check_utils.h View File

@@ -0,0 +1,32 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef INC_31309AA0A4E44C009C22AD9351BF3410
#define INC_31309AA0A4E44C009C22AD9351BF3410

#include "ge_graph_dsl/ge.h"
#include "graph/compute_graph.h"

GE_NS_BEGIN

using GraphCheckFun = std::function<void(const ::GE_NS::ComputeGraphPtr &)>;
struct CheckUtils {
static bool CheckGraph(const std::string &phase_id, const GraphCheckFun &fun);
static void init();
};

GE_NS_END

#endif

tests/framework/utils/builder/tensor_builder_utils.cc → tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/filter_scope_guard.h View File

@@ -1,17 +1,32 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensor_builder_utils.h"
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef C8B32320BD4943D588594B82FFBF2685
#define C8B32320BD4943D588594B82FFBF2685

#include <vector>
#include <string>
#include "ge_graph_dsl/ge.h"

GE_NS_BEGIN

struct FilterScopeGuard {
FilterScopeGuard(const std::vector<std::string> &);
~FilterScopeGuard();
};

GE_NS_END

#endif

+ 59
- 0
tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/graph_assert.h View File

@@ -0,0 +1,59 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef AD954C4ADF5B44F5B1CC8BCD72EE9ED6
#define AD954C4ADF5B44F5B1CC8BCD72EE9ED6

#include "ge_graph_dsl/ge.h"
#include "ge_graph_dsl/assert/check_utils.h"
#include "ge_graph_dsl/assert/assert_error.h"
#include "ge_graph_dsl/assert/filter_scope_guard.h"

GE_NS_BEGIN

#ifdef GTEST_MESSAGE_AT_
#define GRAPH_CHECK_MESSAGE(file, line, message) \
GTEST_MESSAGE_AT_(file, line, message, ::testing::TestPartResult::kFatalFailure)
#elif
#define GRAPH_CHECK_MESSAGE(file, line, message) throw AssertError(file, line, message)
#endif

namespace detail {
struct GraphAssert {
GraphAssert(const char *file, unsigned int line, const std::string &phase_id)
: file_(file), line_(line), phase_id_(phase_id) {}

void operator|(const ::GE_NS::GraphCheckFun &check_fun) {
bool ret = ::GE_NS::CheckUtils::CheckGraph(phase_id_, check_fun);
if (!ret) {
auto message = "expect dump graph in phase: [" + phase_id_ + "], while not find the dump graph! ";
GRAPH_CHECK_MESSAGE(file_, line_, message.c_str());
}
}

private:
const char *file_;
unsigned int line_;
const std::string phase_id_;
};
} // namespace detail

#define DUMP_GRAPH_WHEN(...) ::GE_NS::FilterScopeGuard guard__COUNTER__({__VA_ARGS__});
#define CHECK_GRAPH(phase_id) \
::GE_NS::detail::GraphAssert(__FILE__, __LINE__, #phase_id) | [&](const ::GE_NS::ComputeGraphPtr &graph)

GE_NS_END

#endif

+ 2
- 4
tests/framework/ge_graph_dsl/include/ge_graph_dsl/op_desc/op_desc_cfg.h View File

@@ -33,14 +33,12 @@ struct OpDescCfg {
std::vector<int64_t> shape_;
};

OpDescCfg(const OpType &type, int in_cnt = 0, int out_cnt = 0, Format format = FORMAT_NCHW,
OpDescCfg(const OpType &type, int in_cnt = 1, int out_cnt = 1, Format format = FORMAT_NCHW,
DataType data_type = DT_FLOAT, std::vector<int64_t> shape = {1, 1, 224, 224})
: type_(type), in_cnt_(in_cnt), out_cnt_(out_cnt), default_tensor_(format, data_type, shape) {}

protected:
OpType GetType() const {
return type_;
}
OpType GetType() const { return type_; }
OpType type_;
int in_cnt_;
int out_cnt_;


+ 26
- 0
tests/framework/ge_graph_dsl/src/assert/assert_error.cc View File

@@ -0,0 +1,26 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ge_graph_dsl/assert/assert_error.h"

GE_NS_BEGIN

AssertError::AssertError(const char *file, int line, const std::string &info) {
this->info = std::string(file) + ":" + std::to_string(line) + "\n" + info;
}

const char *AssertError::what() const noexcept { return info.c_str(); }

GE_NS_END

+ 34
- 0
tests/framework/ge_graph_dsl/src/assert/check_utils.cc View File

@@ -0,0 +1,34 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ge_graph_dsl/assert/check_utils.h"
#include "graph/utils/dumper/ge_graph_dumper.h"
#include "ge_graph_default_checker.h"
#include "ge_graph_check_dumper.h"

GE_NS_BEGIN

bool CheckUtils::CheckGraph(const std::string &phase_id, const GraphCheckFun &fun) {
auto &dumper = dynamic_cast<GeGraphCheckDumper &>(GraphDumperRegistry::GetDumper());
return dumper.CheckFor(GeGraphDefaultChecker(phase_id, fun));
}

void CheckUtils::init() {
static GeGraphCheckDumper checkDumper;
GraphDumperRegistry::Register(checkDumper);
}

GE_NS_END

+ 31
- 0
tests/framework/ge_graph_dsl/src/assert/filter_scope_guard.cc View File

@@ -0,0 +1,31 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ge_graph_dsl/assert/filter_scope_guard.h"
#include "graph/utils/dumper/ge_graph_dumper.h"
#include "ge_dump_filter.h"

GE_NS_BEGIN

namespace {
GeDumpFilter &GetDumpFilter() { return dynamic_cast<GeDumpFilter &>(GraphDumperRegistry::GetDumper()); }
} // namespace

FilterScopeGuard::FilterScopeGuard(const std::vector<std::string> &filter) { GetDumpFilter().Update(filter); }

FilterScopeGuard::~FilterScopeGuard() { GetDumpFilter().Reset(); }

GE_NS_END

+ 33
- 0
tests/framework/ge_graph_dsl/src/assert/ge_dump_filter.h View File

@@ -0,0 +1,33 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef INC_4C6224E8F7474EF89B18CCB0E4B19FD6
#define INC_4C6224E8F7474EF89B18CCB0E4B19FD6

#include <vector>
#include <string>
#include "ge_graph_dsl/ge.h"
#include "easy_graph/infra/keywords.h"

GE_NS_BEGIN

INTERFACE(GeDumpFilter) {
ABSTRACT(void Update(const std::vector<std::string> &));
ABSTRACT(void Reset());
};

GE_NS_END

#endif

+ 79
- 0
tests/framework/ge_graph_dsl/src/assert/ge_graph_check_dumper.cc View File

@@ -0,0 +1,79 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ge_graph_check_dumper.h"
#include "graph/model.h"
#include "graph/buffer.h"
#include "graph/utils/graph_utils.h"
#include "ge_graph_default_checker.h"

GE_NS_BEGIN

GeGraphCheckDumper::GeGraphCheckDumper() { Reset(); }

bool GeGraphCheckDumper::IsNeedDump(const std::string &suffix) const {
auto iter = std::find(suffixes_.begin(), suffixes_.end(), suffix);
return (iter != suffixes_.end());
}

void GeGraphCheckDumper::Dump(const ge::ComputeGraphPtr &graph, const std::string &suffix) {
if (!IsNeedDump(suffix)) {
return;
}
auto iter = buffers_.find(suffix);
if (iter != buffers_.end()) {
DumpGraph(graph, iter->second);
} else {
buffers_[suffix] = Buffer();
DumpGraph(graph, buffers_.at(suffix));
}
}

bool GeGraphCheckDumper::CheckFor(const GeGraphChecker &checker) {
auto iter = buffers_.find(checker.PhaseId());
if (iter == buffers_.end()) {
return false;
}
DoCheck(checker, iter->second);
return true;
}

void GeGraphCheckDumper::DoCheck(const GeGraphChecker &checker, ::GE_NS::Buffer &buffer) {
Model model("", "");
Model::Load(buffer.GetData(), buffer.GetSize(), model);
auto load_graph = model.GetGraph();
checker.Check(GraphUtils::GetComputeGraph(load_graph));
}

void GeGraphCheckDumper::DumpGraph(const ge::ComputeGraphPtr &graph, ::GE_NS::Buffer &buffer) {
Model model("", "");
buffer.clear();
model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph));
model.Save(buffer, true);
}

void GeGraphCheckDumper::Update(const std::vector<std::string> &new_suffixes_) {
suffixes_ = new_suffixes_;
buffers_.clear();
}

void GeGraphCheckDumper::Reset() {
static std::vector<std::string> default_suffixes_{"PreRunAfterBuild"};
suffixes_ = default_suffixes_;
buffers_.clear();
}

GE_NS_END

+ 49
- 0
tests/framework/ge_graph_dsl/src/assert/ge_graph_check_dumper.h View File

@@ -0,0 +1,49 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef INC_8EFED0015C27464897BF64531355C810
#define INC_8EFED0015C27464897BF64531355C810

#include "ge_graph_dsl/ge.h"
#include "graph/utils/dumper/ge_graph_dumper.h"
#include "ge_dump_filter.h"
#include <string>

GE_NS_BEGIN

struct GeGraphChecker;

struct GeGraphCheckDumper : GeGraphDumper, GeDumpFilter {
GeGraphCheckDumper();
virtual void Dump(const ge::ComputeGraphPtr &graph, const std::string &suffix);
bool CheckFor(const GeGraphChecker &checker);

private:
void DoCheck(const GeGraphChecker &checker, ::GE_NS::Buffer &buffer);
void DumpGraph(const ge::ComputeGraphPtr &graph, ::GE_NS::Buffer &buffer);

private:
void Update(const std::vector<std::string> &) override;
void Reset() override;
bool IsNeedDump(const std::string &suffix) const;

private:
std::map<std::string, ::GE_NS::Buffer> buffers_;
std::vector<std::string> suffixes_;
};

GE_NS_END

#endif

+ 32
- 0
tests/framework/ge_graph_dsl/src/assert/ge_graph_checker.h View File

@@ -0,0 +1,32 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef INC_5960A8F437324904BEE0690271258762
#define INC_5960A8F437324904BEE0690271258762

#include "ge_graph_dsl/ge.h"
#include "easy_graph/infra/keywords.h"
#include "graph/compute_graph.h"

GE_NS_BEGIN

INTERFACE(GeGraphChecker) {
ABSTRACT(const std::string &PhaseId() const);
ABSTRACT(void Check(const ge::ComputeGraphPtr &graph) const);
};

GE_NS_END

#endif

+ 28
- 0
tests/framework/ge_graph_dsl/src/assert/ge_graph_default_checker.cc View File

@@ -0,0 +1,28 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ge_graph_default_checker.h"

GE_NS_BEGIN

GeGraphDefaultChecker::GeGraphDefaultChecker(const std::string &phase_id, const GraphCheckFun &check_fun)
: phase_id_(phase_id), check_fun_(check_fun) {}

const std::string &GeGraphDefaultChecker::PhaseId() const { return phase_id_; }

void GeGraphDefaultChecker::Check(const ge::ComputeGraphPtr &graph) const { return check_fun_(graph); }

GE_NS_END

+ 41
- 0
tests/framework/ge_graph_dsl/src/assert/ge_graph_default_checker.h View File

@@ -0,0 +1,41 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef BCF4D96BE9FC48938DE7B7E93B551C54
#define BCF4D96BE9FC48938DE7B7E93B551C54

#include "ge_graph_dsl/ge.h"
#include "ge_graph_checker.h"
#include "graph/compute_graph.h"

GE_NS_BEGIN

using GraphCheckFun = std::function<void(const ::GE_NS::ComputeGraphPtr &)>;

struct GeGraphDefaultChecker : GeGraphChecker {
GeGraphDefaultChecker(const std::string &, const GraphCheckFun &);

private:
const std::string &PhaseId() const override;
void Check(const ge::ComputeGraphPtr &graph) const override;

private:
const std::string phase_id_;
const GraphCheckFun check_fun_;
};

GE_NS_END

#endif

tests/framework/ge_graph_dsl/src/op_desc_cfg_box.cc → tests/framework/ge_graph_dsl/src/op_desc/op_desc_cfg_box.cc View File


tests/framework/ge_graph_dsl/src/op_desc_cfg_repo.cc → tests/framework/ge_graph_dsl/src/op_desc/op_desc_cfg_repo.cc View File

@@ -23,15 +23,22 @@ GE_NS_BEGIN

namespace {

#define OP_CFG(optype, ...) \
{ \
optype, OpDescCfg { \
optype, __VA_ARGS__ \
} \
#define OP_CFG(optype, ...) \
{ \
optype, OpDescCfg { optype, __VA_ARGS__ } \
}

static std::map<OpType, OpDescCfg> cfg_repo{OP_CFG(DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}),
OP_CFG(ADD, 2, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}),
OP_CFG(ENTER, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}),
OP_CFG(MERGE, 2, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}),
OP_CFG(CONSTANT, 0, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}),
OP_CFG(LESS, 2, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}),
OP_CFG(LOOPCOND, 1, 1, FORMAT_NCHW, DT_BOOL, {1, 1, 224, 224}),
OP_CFG(SWITCH, 2, 2, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}),
OP_CFG(EXIT, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}),
OP_CFG(NEXTITERATION, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}),
OP_CFG(NETOUTPUT, 2, 2, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}),
OP_CFG(VARIABLE, 1, 1)};
} // namespace


tests/framework/ge_graph_dsl/src/op_desc_ptr_box.cc → tests/framework/ge_graph_dsl/src/op_desc/op_desc_ptr_box.cc View File

@@ -19,6 +19,4 @@

USING_GE_NS

OpDescPtr OpDescPtrBox::Build(const ::EG_NS::NodeId &id) const {
return op_;
}
OpDescPtr OpDescPtrBox::Build(const ::EG_NS::NodeId &id) const { return op_; }

tests/framework/ge_graph_dsl/src/ge_graph_vistor.cc → tests/framework/ge_graph_dsl/src/vistor/ge_graph_visitor.cc View File

@@ -36,17 +36,11 @@ GE_NS_BEGIN

GeGraphVisitor::GeGraphVisitor() : build_graph_(std::make_shared<ComputeGraph>("")) {}

void GeGraphVisitor::reset(const ComputeGraphPtr &graph) {
build_graph_ = graph;
}
void GeGraphVisitor::reset(const ComputeGraphPtr &graph) { build_graph_ = graph; }

Graph GeGraphVisitor::BuildGeGraph() const {
return GraphUtils::CreateGraphFromComputeGraph(build_graph_);
}
Graph GeGraphVisitor::BuildGeGraph() const { return GraphUtils::CreateGraphFromComputeGraph(build_graph_); }

ComputeGraphPtr GeGraphVisitor::BuildComputeGraph() const {
return build_graph_;
}
ComputeGraphPtr GeGraphVisitor::BuildComputeGraph() const { return build_graph_; }

Status GeGraphVisitor::Visit(const ::EG_NS::Graph &graph) {
build_graph_->SetName(graph.GetName());

tests/framework/ge_graph_dsl/src/ge_subgraph_vistor.cc → tests/framework/ge_graph_dsl/src/vistor/ge_subgraph_vistor.cc View File


tests/framework/ge_graph_dsl/src/graph_dsl.cc → tests/framework/ge_graph_dsl/src/vistor/graph_dsl.cc View File


+ 1
- 1
tests/framework/ge_graph_dsl/tests/CMakeLists.txt View File

@@ -26,7 +26,7 @@ target_compile_options(ge_graph_dsl_test PRIVATE
)
set_target_properties(ge_graph_dsl_test PROPERTIES CXX_STANDARD 17)

target_link_libraries(ge_graph_dsl_test PUBLIC gtest gtest_main ge_graph_dsl)
target_link_libraries(ge_graph_dsl_test PUBLIC gtest ge_graph_dsl)

include(CTest)
enable_testing()

+ 129
- 0
tests/framework/ge_graph_dsl/tests/check_graph_test.cc View File

@@ -0,0 +1,129 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "gtest/gtest.h"
#include "easy_graph/layout/graph_layout.h"
#include "easy_graph/layout/engines/graph_easy/graph_easy_option.h"
#include "easy_graph/layout/engines/graph_easy/graph_easy_executor.h"
#include "ge_graph_dsl/graph_dsl.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/dumper/ge_graph_dumper.h"
#include "framework/common/types.h"
#include "ge_graph_dsl/assert/graph_assert.h"
#include "graph/model.h"
#include "graph/buffer.h"

USING_GE_NS

class CheckGraphTest : public testing::Test {
private:
EG_NS::GraphEasyExecutor executor;

protected:
void SetUp() { EG_NS::GraphLayout::GetInstance().Config(executor, nullptr); }
void TearDown() {}
};

TEST_F(CheckGraphTest, test_ge_graph_dump_is_work) {
DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); };

DUMP_GRAPH_WHEN("after_build");
GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "after_build");

CHECK_GRAPH(after_build) {
ASSERT_EQ(graph->GetName(), "g1");
ASSERT_EQ(graph->GetAllNodesSize(), 2);
};
}

TEST_F(CheckGraphTest, test_ge_graph_dump_two_phase) {
DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); };
DEF_GRAPH(g2) {
CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD));
CTRL_CHAIN(NODE("data2", DATA)->NODE("add", ADD));
};

DUMP_GRAPH_WHEN("before_build", "after_build");

GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "before_build");
GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g2), "after_build");

CHECK_GRAPH(before_build) {
ASSERT_EQ(graph->GetName(), "g1");
ASSERT_EQ(graph->GetAllNodesSize(), 2);
};

CHECK_GRAPH(after_build) {
ASSERT_EQ(graph->GetName(), "g2");
ASSERT_EQ(graph->GetAllNodesSize(), 3);
};
}

TEST_F(CheckGraphTest, test_ge_graph_dump_one_phase_two_times) {
DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); };
DEF_GRAPH(g2) {
CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD));
CTRL_CHAIN(NODE("data2", DATA)->NODE("add", ADD));
};

DUMP_GRAPH_WHEN("before_build")

GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "before_build");
GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g2), "before_build");

CHECK_GRAPH(before_build) {
ASSERT_EQ(graph->GetName(), "g2");
ASSERT_EQ(graph->GetAllNodesSize(), 3);
};
}

TEST_F(CheckGraphTest, test_check_phases_is_work) {
DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); };

DUMP_GRAPH_WHEN("before_build");
GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "after_build");
auto ret = ::GE_NS::CheckUtils::CheckGraph("after_build", [&](const ::GE_NS::ComputeGraphPtr &graph) {});
ASSERT_FALSE(ret);
}

TEST_F(CheckGraphTest, test_check_one_phase_dump_another_not_dump) {
DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); };

DUMP_GRAPH_WHEN("before_build");
GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "before_build");
GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "after_build");

CHECK_GRAPH(before_build) {
ASSERT_EQ(graph->GetName(), "g1");
ASSERT_EQ(graph->GetAllNodesSize(), 2);
};
}

TEST_F(CheckGraphTest, test_model_serialize_and_unserialize_success) {
DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); };
auto ge_graph = ToGeGraph(g1);

ge::Model model("", "");
model.SetGraph(ge_graph);
Buffer buffer;
model.Save(buffer, true);

ge::Model loadModel("", "");
Model::Load(buffer.GetData(), buffer.GetSize(), loadModel);
auto load_graph = loadModel.GetGraph();

ASSERT_EQ(load_graph.GetName(), "g1");
ASSERT_EQ(load_graph.GetAllNodes().size(), 2);
}

+ 16
- 28
tests/framework/ge_graph_dsl/tests/graph_dsl_test.cc View File

@@ -37,17 +37,13 @@ class GraphDslTest : public testing::Test {
EG_NS::GraphEasyExecutor executor;

protected:
void SetUp() {
EG_NS::GraphLayout::GetInstance().Config(executor, nullptr);
}
void SetUp() { EG_NS::GraphLayout::GetInstance().Config(executor, nullptr); }

void TearDown() {}
};

TEST_F(GraphDslTest, test_build_graph_from_optype_with_name) {
DEF_GRAPH(g1) {
CHAIN(NODE("data1", DATA)->NODE("add", ADD));
});
DEF_GRAPH(g1) { CHAIN(NODE("data1", DATA)->NODE("add", ADD)); };

auto geGraph = ToGeGraph(g1);
auto computeGraph = ToComputeGraph(g1);
@@ -57,9 +53,7 @@ TEST_F(GraphDslTest, test_build_graph_from_optype_with_name) {
}

TEST_F(GraphDslTest, test_build_graph_with_name) {
DEF_GRAPH(g1, "sample_graph") {
CHAIN(NODE("data1", DATA)->NODE("add", ADD));
});
DEF_GRAPH(g1, "sample_graph") { CHAIN(NODE("data1", DATA)->NODE("add", ADD)); };

auto geGraph = ToGeGraph(g1);

@@ -72,7 +66,7 @@ TEST_F(GraphDslTest, test_build_from_from_op_desc_ptr) {
auto data = std::make_shared<OpDesc>("data1", DATA);
auto add = std::make_shared<OpDesc>("Add", ADD);
CHAIN(NODE(data)->NODE(add));
});
};

auto geGraph = ToGeGraph(g1);

@@ -84,7 +78,7 @@ TEST_F(GraphDslTest, test_build_from_op_desc_cfg) {
auto datCfg = OP_CFG(DATA).InCnt(1).OutCnt(1);
auto addCfg = OP_CFG(DATA).InCnt(1).OutCnt(1);
CHAIN(NODE("data1", datCfg)->NODE("add", addCfg));
});
};

auto geGraph = ToGeGraph(g1);

@@ -92,9 +86,7 @@ TEST_F(GraphDslTest, test_build_from_op_desc_cfg) {
}

TEST_F(GraphDslTest, test_build_from_op_desc_cfg_inline) {
DEF_GRAPH(g1) {
CHAIN(NODE("data1", OP_CFG(DATA).InCnt(1).OutCnt(1))->NODE("add", OP_CFG(ADD).InCnt(2).OutCnt(1)));
});
DEF_GRAPH(g1) { CHAIN(NODE("data1", OP_CFG(DATA).InCnt(1).OutCnt(1))->NODE("add", OP_CFG(ADD).InCnt(2).OutCnt(1))); };

auto geGraph = ToGeGraph(g1);

@@ -102,9 +94,7 @@ TEST_F(GraphDslTest, test_build_from_op_desc_cfg_inline) {
}

TEST_F(GraphDslTest, test_build_from_control_chain) {
DEF_GRAPH(g1) {
CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD));
});
DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); };

auto geGraph = ToGeGraph(g1);

@@ -112,9 +102,7 @@ TEST_F(GraphDslTest, test_build_from_control_chain) {
}

TEST_F(GraphDslTest, test_build_from_data_chain) {
DEF_GRAPH(g1) {
DATA_CHAIN(NODE("data1", DATA)->NODE("add", ADD));
});
DEF_GRAPH(g1) { DATA_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); };

auto geGraph = ToGeGraph(g1);

@@ -125,7 +113,7 @@ TEST_F(GraphDslTest, test_build_from_data_chain_with_edge) {
DEF_GRAPH(g1) {
CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD));
CHAIN(NODE("data1", DATA)->EDGE(2, 2)->NODE("add"));
});
};

auto geGraph = ToGeGraph(g1);

@@ -136,7 +124,7 @@ TEST_F(GraphDslTest, test_build_graph_reused_before_node) {
DEF_GRAPH(g1) {
CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD));
CHAIN(NODE("data1")->EDGE(2, 2)->NODE("add"));
});
};

auto geGraph = ToGeGraph(g1);

@@ -147,7 +135,7 @@ TEST_F(GraphDslTest, test_build_graph_with_constant_folding) {
DEF_GRAPH(g1) {
CHAIN(NODE("data1", DATA)->NODE("add", ADD));
CHAIN(NODE("data2", DATA)->NODE("add"));
});
};

auto geGraph = ToGeGraph(g1);

@@ -168,7 +156,7 @@ TEST_F(GraphDslTest, test_build_complex_normal_graph_build_suggested) {
->NODE("Add4")
->NODE("Add5")
->NODE("net_output", NETOUTPUT));
});
};

auto geGraph = ToGeGraph(g1);

@@ -187,7 +175,7 @@ TEST_F(GraphDslTest, test_build_complex_mult_normal_graph_build) {
CHAIN(NODE("add2")->NODE("net_output"));
CHAIN(NODE("add3")->NODE("net_output"));
CTRL_CHAIN(NODE("add1")->NODE("add2")->NODE("add3"));
});
};

auto geGraph = ToGeGraph(g1);

@@ -198,17 +186,17 @@ TEST_F(GraphDslTest, test_build_graph_with_sub_graph) {
DEF_GRAPH(sub_1) {
CHAIN(NODE("data_i", DATA)->NODE("less", LESS)->NODE("netoutput", NETOUTPUT));
CHAIN(NODE("const_5", CONSTANTOP)->NODE("less"));
});
};

DEF_GRAPH(sub_2) {
CHAIN(NODE("data_a", DATA)->NODE("mul", MUL)->NODE("netoutput", NETOUTPUT));
CHAIN(NODE("const_2", CONSTANTOP)->NODE("mul"));
});
};

DEF_GRAPH(g1) {
CHAIN(NODE("data_a", DATA)->NODE("while", WHILE, sub_1, sub_2)->NODE("netoutput", NETOUTPUT));
CHAIN(NODE("data_i", DATA)->NODE("while"));
});
};

sub_1.Layout();
sub_2.Layout();


+ 6
- 0
tests/framework/ge_graph_dsl/tests/stub/optype_stub.cc View File

@@ -30,5 +30,11 @@ REGISTER_OPTYPE_DEFINE(MUL, "Mul");
REGISTER_OPTYPE_DEFINE(NETOUTPUT, "NetOutput");
REGISTER_OPTYPE_DEFINE(ADD, "Add");
REGISTER_OPTYPE_DEFINE(WHILE, "While");
REGISTER_OPTYPE_DEFINE(ENTER, "Enter");
REGISTER_OPTYPE_DEFINE(MERGE, "Merge");
REGISTER_OPTYPE_DEFINE(LOOPCOND, "Loopcond");
REGISTER_OPTYPE_DEFINE(SWITCH, "Switch");
REGISTER_OPTYPE_DEFINE(EXIT, "Exit");
REGISTER_OPTYPE_DEFINE(NEXTITERATION, "Nextiteration");

GE_NS_END

tests/framework/utils/builder/tensor_builder_utils.h → tests/framework/ge_graph_dsl/tests/test_main.cc View File

@@ -1,22 +1,25 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H
#define GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H
class tensor_builder_utils {};
#endif // GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include "ge_graph_dsl/assert/check_utils.h"

int main(int argc, char **argv) {
::GE_NS::CheckUtils::init();
testing::InitGoogleTest(&argc, argv);
int ret = RUN_ALL_TESTS();
return ret;
}

+ 0
- 48
tests/framework/utils/builder/graph_builder_utils.cc View File

@@ -1,48 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph_builder_utils.h"
#include "inc/external/graph/operator.h"
#include "inc/external/graph/operator_factory.h"
#include "graph/utils/graph_utils.h"

namespace ge {
namespace st {
NodePtr ComputeGraphBuilder::AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt,
Format format, DataType data_type, std::vector<int64_t> shape) {
auto tensor_desc = std::make_shared<GeTensorDesc>();
tensor_desc->SetShape(GeShape(std::move(shape)));
tensor_desc->SetFormat(format);
tensor_desc->SetDataType(data_type);

auto op_desc = std::make_shared<OpDesc>(name, type);
for (int i = 0; i < in_cnt; ++i) {
op_desc->AddInputDesc(tensor_desc->Clone());
}
for (int i = 0; i < out_cnt; ++i) {
op_desc->AddOutputDesc(tensor_desc->Clone());
}

return graph_->AddNode(op_desc);
}
void ComputeGraphBuilder::AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx) {
GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_idx), dst_node->GetInDataAnchor(dst_idx));
}
void ComputeGraphBuilder::AddControlEdge(NodePtr &src_node, NodePtr &dst_node) {
GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor());
}
} // namespace st
} // namespace ge

+ 0
- 55
tests/framework/utils/builder/graph_builder_utils.h View File

@@ -1,55 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H
#define GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H

#include <string>
#include <vector>

#include "graph/compute_graph.h"
#include "graph/utils/graph_utils.h"
#include "graph/graph.h"
#include "graph/node.h"

namespace ge {
namespace st {
class ComputeGraphBuilder {
public:
explicit ComputeGraphBuilder(const std::string &name) {
graph_ = std::make_shared<ComputeGraph>(name);
}
NodePtr AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt,
Format format = FORMAT_NCHW, DataType data_type = DT_FLOAT,
std::vector<int64_t> shape = {1, 1, 224, 224});
void AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx);
void AddControlEdge(NodePtr &src_node, NodePtr &dst_node);
ComputeGraphPtr GetComputeGraph() {
graph_->TopologicalSorting();
return graph_;
}
Graph GetGraph() {
graph_->TopologicalSorting();
return GraphUtils::CreateGraphFromComputeGraph(graph_);
}

private:
ComputeGraphPtr graph_;
};
} // namespace st
} // namespace ge

#endif // GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H

+ 1
- 1
tests/st/testcase/CMakeLists.txt View File

@@ -8,7 +8,7 @@ target_include_directories(graph_engine_test

set_target_properties(graph_engine_test PROPERTIES CXX_STANDARD 17)

target_link_libraries(graph_engine_test PRIVATE gtest framework)
target_link_libraries(graph_engine_test PRIVATE gtest ge_graph_dsl ge_with_env)

include(CTest)
enable_testing()

+ 49
- 78
tests/st/testcase/test_framework_dummy.cc View File

@@ -15,23 +15,12 @@
*/

#include <gtest/gtest.h>
#include <map>
#include "external/ge/ge_api.h"
#include "ge_running_env/fake_engine.h"
#include "graph/debug/ge_attr_define.h"
#include "framework/common/types.h"

#include "builder/graph_builder_utils.h"
#include "ge_running_env/ge_running_env_faker.h"

#include "graph/operator_reg.h"
#include "graph/operator.h"
#define protected public
#define private public
#include "graph/utils/op_desc_utils.h"
#include "ge_graph_dsl/graph_dsl.h"
#undef protected
#undef private
#include "ge_graph_dsl/assert/graph_assert.h"

using namespace std;
using namespace ge;
@@ -57,76 +46,58 @@ namespace {
*
**/
Graph BuildV1ControlFlowGraph() {
// build graph
st::ComputeGraphBuilder graphBuilder("g1");
auto data_i = graphBuilder.AddNode("data_i", DATA, 1, 1);
auto enter_i = graphBuilder.AddNode("enter_i", ENTER, 1, 1);
ge::AttrUtils::SetStr(enter_i->GetOpDesc(), ENTER_ATTR_FRAME_NAME, "1");
auto merge_i = graphBuilder.AddNode("merge_i", MERGE, 2, 1);
auto const_5 = graphBuilder.AddNode("const_5", CONSTANT, 0, 1);
auto less = graphBuilder.AddNode("less", LESS, 2, 1);
auto loopcond = graphBuilder.AddNode("loopcond", LOOPCOND, 1, 1, FORMAT_NCHW, DT_BOOL);
auto switch_i = graphBuilder.AddNode("switch_i", SWITCH, 2, 2);
auto exit_i = graphBuilder.AddNode("switch_i", EXIT, 1, 1);
auto const_1 = graphBuilder.AddNode("const_1", CONSTANT, 0, 1);
auto add = graphBuilder.AddNode("add", ADD, 2, 1);
auto next_iteration_i = graphBuilder.AddNode("next_iteration_i", NEXTITERATION, 1, 1);

auto data_a = graphBuilder.AddNode("data_a", DATA, 1, 1);
auto enter_a = graphBuilder.AddNode("enter_a", ENTER, 1, 1);
ge::AttrUtils::SetStr(enter_a->GetOpDesc(), ENTER_ATTR_FRAME_NAME, "1");
auto merge_a = graphBuilder.AddNode("merge_a", MERGE, 2, 1);
auto switch_a = graphBuilder.AddNode("switch_a", SWITCH, 2, 2);
auto exit_a = graphBuilder.AddNode("exit_a", EXIT, 1, 1);
auto mul = graphBuilder.AddNode("mul", MUL, 2, 1);
auto const_2 = graphBuilder.AddNode("const_2", CONSTANT, 0, 1);
auto next_iteration_a = graphBuilder.AddNode("next_iteration_a", NEXTITERATION, 1, 1);
auto netoutput = graphBuilder.AddNode("netoutput", NETOUTPUT, 2, 2);
// i = i+1
graphBuilder.AddDataEdge(data_i, 0, enter_i, 0);
graphBuilder.AddDataEdge(enter_i, 0, merge_i, 0);
graphBuilder.AddDataEdge(next_iteration_i, 0, merge_i, 1);
graphBuilder.AddDataEdge(merge_i, 0, less, 0);
graphBuilder.AddDataEdge(const_5, 0, less, 1);
graphBuilder.AddDataEdge(less, 0, loopcond, 0);
graphBuilder.AddDataEdge(loopcond, 0, switch_i, 1);
graphBuilder.AddDataEdge(merge_i, 0, switch_i, 0);
graphBuilder.AddDataEdge(switch_i, 0, exit_i, 0);
graphBuilder.AddDataEdge(switch_i, 1, add, 0);
graphBuilder.AddDataEdge(const_1, 0, add, 1);
graphBuilder.AddDataEdge(add, 0, next_iteration_i, 0);
graphBuilder.AddDataEdge(exit_i, 0, netoutput, 1);
// a=a*2
graphBuilder.AddDataEdge(data_a, 0, enter_a, 0);
graphBuilder.AddDataEdge(enter_a, 0, merge_a, 0);
graphBuilder.AddDataEdge(next_iteration_a, 0, merge_a, 1);
graphBuilder.AddDataEdge(loopcond, 0, switch_a, 1);
graphBuilder.AddDataEdge(merge_a, 0, switch_a, 0);
graphBuilder.AddDataEdge(switch_a, 0, exit_a, 0);
graphBuilder.AddDataEdge(switch_a, 1, mul, 0);
graphBuilder.AddDataEdge(const_2, 0, mul, 1);
graphBuilder.AddDataEdge(mul, 0, next_iteration_a, 0);
graphBuilder.AddDataEdge(exit_a, 0, netoutput, 0);
// set const weight
int64_t dims_size = 1;
vector<int64_t> data_vec = {5};
for_each(data_vec.begin(), data_vec.end(), [&](int64_t &data) { dims_size *= data; });
vector<int32_t> data_value_vec(dims_size, 1);
GeTensorDesc data_tensor_desc(GeShape(data_vec), FORMAT_NCHW, DT_INT32);
GeTensorPtr data_tensor =
make_shared<GeTensor>(data_tensor_desc, (uint8_t *)data_value_vec.data(), data_value_vec.size() * sizeof(int32_t));
OpDescUtils::SetWeights(const_5->GetOpDesc(), data_tensor);
OpDescUtils::SetWeights(const_2->GetOpDesc(), data_tensor);
OpDescUtils::SetWeights(const_1->GetOpDesc(), data_tensor);
GeTensorPtr data_tensor = make_shared<GeTensor>(data_tensor_desc, (uint8_t *)data_value_vec.data(),
data_value_vec.size() * sizeof(int32_t));

return graphBuilder.GetGraph();
auto enter = OP_CFG(ENTER).Attr(ENTER_ATTR_FRAME_NAME, "1");
auto const_op = OP_CFG(CONSTANT).Weight(data_tensor);

DEF_GRAPH(g1) {
CHAIN(NODE("data_i", DATA)
->NODE("enter_i", enter)
->EDGE(0, 0)
->NODE("merge_i", MERGE)
->NODE("less", LESS)
->NODE("loopcond", LOOPCOND));
CHAIN(NODE("const_1", const_op)
->EDGE(0, 1)
->NODE("add", ADD)
->NODE("iteration_i", NEXTITERATION)
->EDGE(0, 1)
->NODE("merge_i"));
CHAIN(NODE("const_5", const_op)->EDGE(0, 1)->NODE("less"));
CHAIN(NODE("loopcond")
->EDGE(0, 1)
->NODE("switch_i", SWITCH)
->EDGE(0, 0)
->NODE("exit_i", EXIT)
->EDGE(0, 1)
->NODE("netoutput", NETOUTPUT));
CHAIN(NODE("merge_i")->EDGE(0, 0)->NODE("switch_i")->EDGE(1, 0)->NODE("add"));
CHAIN(NODE("data_a", DATA)
->NODE("enter_a", enter)
->NODE("merge_a", MERGE)
->NODE("switch_a", SWITCH)
->NODE("exit_a", EXIT)
->EDGE(0, 0)
->NODE("netoutput"));
CHAIN(NODE("iteration_a", NEXTITERATION)->EDGE(0, 1)->NODE("merge_a"));
CHAIN(NODE("loopcond")->EDGE(0, 1)->NODE("switch_a")->EDGE(1, 0)->NODE("mul", MUL));
CHAIN(NODE("const_2", const_op)->EDGE(0, 1)->NODE("mul")->EDGE(0, 0)->NODE("iteration_a"));
};
return ToGeGraph(g1);
}
} // namespace
class FrameworkTest : public testing::Test {
protected:
GeRunningEnvFaker ge_env;
void SetUp() { ge_env.InstallDefault(); }
void TearDown() {}
GeRunningEnvFaker ge_env;
};

/// data data
@@ -136,19 +107,19 @@ TEST_F(FrameworkTest, test_framework_add) {
DEF_GRAPH(g1) {
CHAIN(NODE("data1", DATA)->NODE("add", ADD));
CHAIN(NODE("data2", DATA)->NODE("add"));
});
};

auto graph = ToGeGraph(g1);
// new session & add graph
map<AscendString, AscendString> options;
Session session(options);
auto ret = session.AddGraph(1, graph, options);
EXPECT_EQ(ret, SUCCESS);
// build input tensor
session.AddGraph(1, ToGeGraph(g1), options);
std::vector<InputTensorInfo> inputs;
// build_graph through session
ret = session.BuildGraph(1, inputs);
auto ret = session.BuildGraph(1, inputs);
EXPECT_EQ(ret, SUCCESS);
CHECK_GRAPH(PreRunAfterBuild) {
ASSERT_EQ(graph->GetName(), "g1_1");
ASSERT_EQ(graph->GetAllNodesSize(), 4);
};
}

/** data a = 2;


+ 4
- 16
tests/st/testcase/test_ge_opt_info.cc View File

@@ -15,24 +15,12 @@
*/

#include <gtest/gtest.h>
#include "easy_graph/graph/box.h"
#include "easy_graph/graph/node.h"
#include "external/ge/ge_api.h"
#include "easy_graph/builder/graph_dsl.h"
#include "easy_graph/builder/box_builder.h"
#include "easy_graph/layout/graph_layout.h"
#include "easy_graph/layout/engines/graph_easy/graph_easy_option.h"
#include "easy_graph/layout/engines/graph_easy/graph_easy_executor.h"
#include "graph/graph.h"
#include "graph/compute_graph.h"
#include "framework/common/types.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/ge_local_context.h"
#include "ge_graph_dsl/graph_dsl.h"
#include "ge_graph_dsl/op_desc/op_desc_cfg_box.h"
#define protected public
#define private public
#include "ge_opt_info/ge_opt_info.h"
#undef private
#undef protected

namespace ge {
class STEST_opt_info : public testing::Test {
@@ -52,7 +40,7 @@ TEST_F(STEST_opt_info, get_opt_info_all) {
DEF_GRAPH(g1) {
CHAIN(NODE("data1", DATA)->NODE("add", ADD));
CHAIN(NODE("data2", DATA)->NODE("add"));
});
};

auto graph = ToGeGraph(g1);

@@ -95,7 +83,7 @@ TEST_F(STEST_opt_info, get_opt_info_success) {
DEF_GRAPH(g1) {
CHAIN(NODE("data1", DATA)->NODE("add", ADD));
CHAIN(NODE("data2", DATA)->NODE("add"));
});
};

auto graph = ToGeGraph(g1);



+ 2
- 2
tests/st/testcase/test_main.cc View File

@@ -15,9 +15,8 @@
*/

#include <gtest/gtest.h>

#include "common/debug/log.h"
#include "external/ge/ge_api.h"
#include "ge_graph_dsl/assert/check_utils.h"
#include "ge_running_env/include/ge_running_env/ge_running_env_faker.h"

using namespace std;
@@ -31,6 +30,7 @@ int main(int argc, char **argv) {
std::cout << "ge init failed , ret code:" << init_status << endl;
}
GeRunningEnvFaker::BackupEnv();
CheckUtils::init();
testing::InitGoogleTest(&argc, argv);
int ret = RUN_ALL_TESTS();
return ret;


+ 1
- 0
tests/ut/common/graph/CMakeLists.txt View File

@@ -90,6 +90,7 @@ set(SRC_FILES
"${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc"
"${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc"
"${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc"
"${GE_CODE_DIR}/metadef/graph/utils/dumper/ge_graph_dumper.cc"
"${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc"
"${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc"
"${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc"


+ 1
- 0
tests/ut/ge/CMakeLists.txt View File

@@ -102,6 +102,7 @@ set(GRAPH_SRC_FILES
"${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc"
"${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc"
"${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc"
"${GE_CODE_DIR}/metadef/graph/utils/dumper/ge_graph_dumper.cc"
"${GE_CODE_DIR}/metadef/graph/utils/ge_ir_utils.cc"
"${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc"
"${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc"


+ 3
- 1
tests/ut/ge/graph/build/task_generator_unittest.cc View File

@@ -116,7 +116,9 @@ TEST_F(UtestTaskGeneratorTest, FindLastBpFromBpNode) {
TaskGenerator task_generator(nullptr, 0);
auto net_output = graph->FindNode("Node_Output");
// netoutput has no data input, return default value 0
EXPECT_EQ(task_generator.FindLastBpFromBpNode(graph, net_output), 0);
uint32_t bp_index = 0;
EXPECT_EQ(task_generator.FindLastBpFromBpNode(graph, net_output, bp_index), 0);
EXPECT_EQ(bp_index, 2);
}

TEST_F(UtestTaskGeneratorTest, UpdateOpIsVarAttr) {


+ 1
- 1
tests/ut/ge/graph/passes/addn_pass_unittest.cc View File

@@ -72,7 +72,7 @@ TEST(UtestGraphPassesAddnPass, null_pass) {
AddNPass *addn_pass = nullptr;
NamesToPass names_to_pass;
names_to_pass.emplace_back("Test", addn_pass);
EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
EXPECT_NE(pass.Run(names_to_pass), SUCCESS);
}

TEST(UtestGraphPassesAddnPass, null_graph) {


+ 462
- 11
tests/ut/ge/graph/passes/base_pass_unittest.cc View File

@@ -17,7 +17,6 @@
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

#include "gtest/gtest.h"
@@ -26,8 +25,6 @@
#include "graph/passes/base_pass.h"
#undef protected

#include "external/graph/ge_error_codes.h"
#include "framework/common/ge_inner_error_codes.h"
#include "framework/common/types.h"
#include "graph/node.h"
#include "graph/utils/graph_utils.h"
@@ -67,6 +64,54 @@ class UtestTestPass : public BaseNodePass {
names_to_add_repass_.erase(iter);
}
}

iter = names_to_add_repass_immediate_.find(node->GetName());
if (iter != names_to_add_repass_immediate_.end()) {
auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes();
for (const auto &node_name : iter->second) {
for (auto &node_re_pass : all_nodes) {
if (node_re_pass->GetName() == node_name) {
AddImmediateRePassNode(node_re_pass);
break;
}
}
}
if (!dead_loop_) {
names_to_add_repass_immediate_.erase(iter);
}
}

iter = names_to_add_suspend_.find(node->GetName());
if (iter != names_to_add_suspend_.end()) {
auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes();
for (const auto &node_name : iter->second) {
for (auto &node_re_pass : all_nodes) {
if (node_re_pass->GetName() == node_name) {
AddNodeSuspend(node_re_pass);
break;
}
}
}
if (!dead_loop_) {
names_to_add_suspend_.erase(iter);
}
}

iter = names_to_add_resume_.find(node->GetName());
if (iter != names_to_add_resume_.end()) {
auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes();
for (const auto &node_name : iter->second) {
for (auto &node_re_pass : all_nodes) {
if (node_re_pass->GetName() == node_name) {
AddNodeResume(node_re_pass);
break;
}
}
}
if (!dead_loop_) {
names_to_add_resume_.erase(iter);
}
}
// simulate infershape pass
if(node->GetType() == WHILE){
bool need_repass = false;
@@ -85,6 +130,20 @@ class UtestTestPass : public BaseNodePass {
}
return SUCCESS;
}

Status OnSuspendNodesLeaked() override {
// resume all node remain in suspend_nodes when leaked
auto compute_graph = (iter_nodes_.size() > 0) ? iter_nodes_[0]->GetOwnerComputeGraph() : nullptr;
if (compute_graph == nullptr) {
return SUCCESS;
}

for (const auto &node_name : names_to_add_resume_onleaked_) {
auto node_to_resume = compute_graph->FindNode(node_name);
AddNodeResume(node_to_resume);
}
return SUCCESS;
}
void clear() { iter_nodes_.clear(); }
std::vector<NodePtr> GetIterNodes() { return iter_nodes_; }

@@ -94,12 +153,30 @@ class UtestTestPass : public BaseNodePass {
void AddDelNodeName(const std::string &iter_node, const std::string &del_node) {
names_to_add_del_[iter_node].insert(del_node);
}
void AddRePassImmediateNodeName(const std::string &iter_node, const std::string &re_pass_node) {
names_to_add_repass_immediate_[iter_node].insert(re_pass_node);
}

void AddSuspendNodeName(const std::string &iter_node, const std::string &suspend_node) {
names_to_add_suspend_[iter_node].insert(suspend_node);
}
void AddResumeNodeName(const std::string &iter_node, const std::string &resume_node) {
names_to_add_resume_[iter_node].insert(resume_node);
}
void AddResumeNodeNameOnLeaked(const std::string &resume_node) {
names_to_add_resume_onleaked_.insert(resume_node);
}

unsigned int GetRunTimes() { return run_times_; }

private:
std::vector<NodePtr> iter_nodes_;
std::map<std::string, std::unordered_set<std::string>> names_to_add_del_;
std::map<std::string, std::unordered_set<std::string>> names_to_add_repass_;
std::map<std::string, std::unordered_set<std::string>> names_to_add_repass_immediate_;
std::map<std::string, std::unordered_set<std::string>> names_to_add_suspend_;
std::map<std::string, std::unordered_set<std::string>> names_to_add_resume_;
std::unordered_set<std::string> names_to_add_resume_onleaked_;
bool dead_loop_;
unsigned int run_times_;
};
@@ -200,6 +277,26 @@ ComputeGraphPtr BuildGraph3() {
return builder.GetGraph();
}

/// cast1--shape1
/// /
/// data1
/// \
/// transdata1--shape2
ComputeGraphPtr BuildGraph4() {
auto builder = ut::GraphBuilder("g1");
auto data1 = builder.AddNode("data1", DATA, 0, 1);
auto cast1 = builder.AddNode("cast1", CAST, 1, 1);
auto shape1 = builder.AddNode("shape1", SHAPE, 1, 1);
auto transdata1 = builder.AddNode("transdata1", TRANSDATA, 1, 1);
auto shape2 = builder.AddNode("shape2", SHAPE, 1, 1);

builder.AddDataEdge(data1, 0, cast1, 0);
builder.AddDataEdge(data1, 0, transdata1, 0);
builder.AddDataEdge(cast1, 0, shape1, 0);
builder.AddDataEdge(transdata1, 0, shape2, 0);
return builder.GetGraph();
}

void CheckIterOrder(UtestTestPass *pass, std::vector<std::unordered_set<std::string>> &nodes_layers) {
std::unordered_set<std::string> layer_nodes;
size_t layer_index = 0;
@@ -509,15 +606,369 @@ ComputeGraphPtr BuildWhileGraph1() {
}

TEST_F(UTESTGraphPassesBasePass, while_infershape) {
NamesToPass names_to_pass;
auto test_pass = UtestTestPass();
names_to_pass.push_back(std::make_pair("test", &test_pass));
NamesToPass names_to_pass;
auto test_pass = UtestTestPass();
names_to_pass.push_back(std::make_pair("test", &test_pass));

auto graph = BuildWhileGraph1();
auto ge_pass = GEPass(graph);
auto while_node = graph->FindNode("while");
EXPECT_EQ(while_node->GetOpDesc()->GetSubgraphInstanceNames().size(),1);
EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
auto graph = BuildWhileGraph1();
auto ge_pass = GEPass(graph);
auto while_node = graph->FindNode("while");
EXPECT_EQ(while_node->GetOpDesc()->GetSubgraphInstanceNames().size(),1);
EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
}

TEST_F(UTESTGraphPassesBasePass, re_pass_pre_node_immediately) {
auto graph = BuildGraph2();
auto ge_pass = GEPass(graph);
auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
// repass pre_node immediately
test_pass->AddRePassImmediateNodeName("reshape1", "add1");
EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);

EXPECT_EQ(test_pass->GetIterNodes().size(), 9);// todo
std::vector<std::unordered_set<std::string>> layers;
layers.push_back({"data1", "const1", "const2"});
layers.push_back({"shape1"});
layers.push_back({"add1", "addn1"});
layers.push_back({"reshape1", "add1", "sum1"});
CheckIterOrder(test_pass, layers);
}

TEST_F(UTESTGraphPassesBasePass, re_pass_cur_node_immediately) {
auto graph = BuildGraph2();
auto ge_pass = GEPass(graph);
auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
// repass cur_node immediately
test_pass->AddRePassImmediateNodeName("reshape1", "reshape1");
EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);

EXPECT_EQ(test_pass->GetIterNodes().size(), 9);
std::vector<std::unordered_set<std::string>> layers;
layers.push_back({"data1", "const1", "const2"});
layers.push_back({"shape1"});
layers.push_back({"add1", "addn1"});
layers.push_back({"reshape1"});
layers.push_back({"reshape1", "sum1"});
CheckIterOrder(test_pass, layers);
}

TEST_F(UTESTGraphPassesBasePass, re_pass_next_node_immediately) {
auto graph = BuildGraph2();
auto ge_pass = GEPass(graph);
auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
// repass next_node immediately
test_pass->AddRePassImmediateNodeName("reshape1", "sum1");
// repass node after next_node immediately
test_pass->AddRePassImmediateNodeName("add1", "sum1");
EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);

EXPECT_EQ(test_pass->GetIterNodes().size(), 8);
std::vector<std::unordered_set<std::string>> layers;
layers.push_back({"data1", "const1", "const2"});
layers.push_back({"shape1"});
layers.push_back({"add1", "addn1"});
layers.push_back({"reshape1", "sum1"});
CheckIterOrder(test_pass, layers);
}
/**
* A->B->C
* if node B suspend its pre_node A, and C resume A, it is a useless operation, so iter_order should follow normal order
* when C resuem A, A will pass again.
*/
TEST_F(UTESTGraphPassesBasePass, B_suspend_pre_node_A_then_C_resume_A) {
auto graph = BuildGraph2();
auto ge_pass = GEPass(graph);
auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
// add1->reshape1->sum1
test_pass->AddSuspendNodeName("reshape1", "add1");
test_pass->AddResumeNodeName("sum1", "add1");
EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
EXPECT_EQ(test_pass->GetIterNodes().size(), 9);
std::vector<std::unordered_set<std::string>> layers;
layers.push_back({"data1", "const1", "const2"});
layers.push_back({"shape1"});
layers.push_back({"add1", "addn1"});
layers.push_back({"reshape1", "sum1"});
layers.push_back({"add1"});
CheckIterOrder(test_pass, layers);
}

/**
* A->B->C
* if node B suspend its pre_node A, and B resume A, it is a useless operation, so iter_order should follow normal order
* when B resuem A, A will pass again.
*/
TEST_F(UTESTGraphPassesBasePass, B_suspend_pre_node_A_then_B_resume_A) {
auto graph = BuildGraph2();
auto ge_pass = GEPass(graph);
auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
// add1->reshape1->sum1
test_pass->AddSuspendNodeName("reshape1", "add1");
test_pass->AddResumeNodeName("reshape1", "add1");
EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
EXPECT_EQ(test_pass->GetIterNodes().size(), 9);
std::vector<std::unordered_set<std::string>> layers;
layers.push_back({"data1", "const1", "const2"});
layers.push_back({"shape1"});
layers.push_back({"add1", "addn1"});
layers.push_back({"reshape1", "sum1", "add1"});
CheckIterOrder(test_pass, layers);
}

/**
* A->B->C
* if node B resume C(which is not suspended), it is a useless operation, C will not pass.
*/
TEST_F(UTESTGraphPassesBasePass, B_resume_node_not_suspended) {
auto graph = BuildGraph2();
auto ge_pass = GEPass(graph);
auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
// add1->reshape1->sum1
test_pass->AddResumeNodeName("reshape1", "sum1");
EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
EXPECT_EQ(test_pass->GetIterNodes().size(), 8);
std::vector<std::unordered_set<std::string>> layers;
layers.push_back({"data1", "const1", "const2"});
layers.push_back({"shape1"});
layers.push_back({"add1", "addn1"});
layers.push_back({"reshape1", "sum1"});
CheckIterOrder(test_pass, layers);
}

/**
* A->B->C
* if node B suspend its pre_node A, it is a useless operation, so iter_order should follow normal order
* because nobody resume it ,which means A is a leaked node, so return fail
*/
TEST_F(UTESTGraphPassesBasePass, suspend_pre_node_nobody_resume_it_return_failed) {
NamesToPass names_to_pass;
auto test_pass = UtestTestPass();
names_to_pass.push_back(std::make_pair("test", &test_pass));
// suspend pre_node immediately
test_pass.AddSuspendNodeName("reshape1", "add1");
auto graph = BuildGraph2();
auto ge_pass = GEPass(graph);
EXPECT_EQ(ge_pass.Run(names_to_pass), INTERNAL_ERROR);
}

/**
* A->B->C
* if node B suspend its pre_node A, it is a useless operation,
* so iter_order should follow normal order
* resume A on leaked, which means A will pass again
*/
TEST_F(UTESTGraphPassesBasePass, suspend_pre_node_resume_it_onleaked) {
auto graph = BuildGraph2();
auto ge_pass = GEPass(graph);
auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
// suspend pre_node immediately
test_pass->AddSuspendNodeName("reshape1", "add1");
test_pass->AddResumeNodeNameOnLeaked("add1");
EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
std::vector<std::unordered_set<std::string>> layers;
layers.push_back({"data1", "const1", "const2"});
layers.push_back({"shape1"});
layers.push_back({"add1", "addn1"});
layers.push_back({"reshape1", "sum1"});
layers.push_back({"add1"});
CheckIterOrder(test_pass, layers);
}


/// cast1--shape1
/// /
/// data1
/// \
/// transdata1--shape2
/**
* suspend cur node
* cast1 suspend itself, shape2 resume cast1
* iter order follows : data1; cast1,transdata1; shape2; cast1 ; shape1
*/
TEST_F(UTESTGraphPassesBasePass, cast1_suspend_cur_node_shape2_resume_cast1) {
auto graph = BuildGraph4();
auto ge_pass = GEPass(graph);
auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
// suspend pre_node immediately
test_pass->AddSuspendNodeName("cast1", "cast1");
test_pass->AddResumeNodeName("shape2", "cast1");
EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
EXPECT_EQ(test_pass->GetIterNodes().size(), 6);
std::vector<std::unordered_set<std::string>> layers;
layers.push_back({"data1"});
layers.push_back({"cast1","transdata1"});
layers.push_back({"shape2"});
layers.push_back({"cast1", "shape1"});
CheckIterOrder(test_pass, layers);
}
/**
* suspend cur node
* cast1 suspend itself, then resume cast1
* iter order follows : data1; cast1,cast1,transdata1; shape2; shape1.
*/
TEST_F(UTESTGraphPassesBasePass, cast1_suspend_itslef_then_resume_itself) {
auto graph = BuildGraph4();
auto ge_pass = GEPass(graph);
auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
// suspend pre_node immediately
test_pass->AddSuspendNodeName("cast1", "cast1");
test_pass->AddResumeNodeName("cast1", "cast1");
EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
EXPECT_EQ(test_pass->GetIterNodes().size(), 6);
std::vector<std::unordered_set<std::string>> layers;
layers.push_back({"data1"});
layers.push_back({"cast1","transdata1","cast1","shape1", "shape2"});
CheckIterOrder(test_pass, layers);
}
/**
* suspend cur node
* cast1 suspend itself, then resume cast1 on leaked
* iter order follows : data1; cast1,cast1,transdata1; shape2; shape1.
*/
TEST_F(UTESTGraphPassesBasePass, cast1_suspend_itslef_then_resume_onleaked) {
auto graph = BuildGraph4();
auto ge_pass = GEPass(graph);
auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
// suspend pre_node immediately
test_pass->AddSuspendNodeName("cast1", "cast1");
test_pass->AddResumeNodeNameOnLeaked("cast1");
EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
EXPECT_EQ(test_pass->GetIterNodes().size(), 6);
std::vector<std::unordered_set<std::string>> layers;
layers.push_back({"data1"});
layers.push_back({"cast1","transdata1", "shape2"});
layers.push_back({"cast1","shape1"});
CheckIterOrder(test_pass, layers);
}
/**
* suspend next node
* data1 suspend cast1, then resume cast1 on leaked
* iter order follows : data1; transdata1, shape2; cast1, shape1.
*/
TEST_F(UTESTGraphPassesBasePass, data1_suspend_cast1_resume_cast1_onleaked) {
auto graph = BuildGraph4();
auto ge_pass = GEPass(graph);
auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
// suspend pre_node immediately
test_pass->AddSuspendNodeName("data1", "cast1");
test_pass->AddResumeNodeNameOnLeaked("cast1");
EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
EXPECT_EQ(test_pass->GetIterNodes().size(), 5);
std::vector<std::unordered_set<std::string>> layers;
layers.push_back({"data1"});
layers.push_back({"transdata1", "shape2"});
layers.push_back({"cast1","shape1"});
CheckIterOrder(test_pass, layers);
}

/**
* suspend next node
* data1 suspend cast1, nobody resume it
* iter order follows : data1; transdata1, shape2;
* run ret is failed ,because node leaked
*/
TEST_F(UTESTGraphPassesBasePass, data1_suspend_cast1_nobody_resume) {
auto graph = BuildGraph4();
auto ge_pass = GEPass(graph);
auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
// suspend pre_node immediately
test_pass->AddSuspendNodeName("data1", "cast1");
EXPECT_EQ(ge_pass.Run(names_to_pass_), INTERNAL_ERROR);
EXPECT_EQ(test_pass->GetIterNodes().size(), 3);
}


TEST_F(UTESTGraphPassesBasePass, re_pass_pre_node_immediately) {
NamesToPass names_to_pass;
auto test_pass = UtestTestPass();
names_to_pass.push_back(std::make_pair("test", &test_pass));

// repass pre_node immediately
test_pass.AddRePassImmediateNodeName("reshape1", "add1");

auto graph = BuildGraph2();
auto ge_pass = GEPass(graph);
EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
EXPECT_EQ(test_pass.GetIterNodes().size(), 9);// todo
std::vector<std::unordered_set<std::string>> layers;
layers.push_back({"data1", "const1", "const2"});
layers.push_back({"shape1"});
layers.push_back({"add1", "addn1"});
layers.push_back({"reshape1", "add1", "sum1"});
CheckIterOrder(&test_pass, layers);
}
/// sum1
/// / \.
/// / \.
/// / \.
/// reshape1 addn1
/// | c |
/// add1 <--- shape1
/// / \ |
/// | | |
/// data1 const1 const2
TEST_F(UTESTGraphPassesBasePass, re_pass_cur_node_immediately) {
NamesToPass names_to_pass;
auto test_pass = UtestTestPass();
names_to_pass.push_back(std::make_pair("test", &test_pass));

// repass cur_node immediately
test_pass.AddRePassImmediateNodeName("reshape1", "reshape1");

auto graph = BuildGraph2();
auto ge_pass = GEPass(graph);
EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
EXPECT_EQ(test_pass.GetIterNodes().size(), 9);// todo
std::vector<std::unordered_set<std::string>> layers;
layers.push_back({"data1", "const1", "const2"});
layers.push_back({"shape1"});
layers.push_back({"add1", "addn1"});
layers.push_back({"reshape1"});
layers.push_back({"reshape1", "sum1"});
CheckIterOrder(&test_pass, layers);
}

TEST_F(UTESTGraphPassesBasePass, re_pass_next_node_immediately) {
NamesToPass names_to_pass;
auto test_pass = UtestTestPass();
names_to_pass.push_back(std::make_pair("test", &test_pass));

// repass next_node immediately
test_pass.AddRePassImmediateNodeName("reshape1", "sum1");
// repass node after next_node immediately
test_pass.AddRePassImmediateNodeName("add1", "sum1");

auto graph = BuildGraph2();
auto ge_pass = GEPass(graph);
EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
EXPECT_EQ(test_pass.GetIterNodes().size(), 8);// todo
std::vector<std::unordered_set<std::string>> layers;
layers.push_back({"data1", "const1", "const2"});
layers.push_back({"shape1"});
layers.push_back({"add1", "addn1"});
layers.push_back({"reshape1", "sum1"});
CheckIterOrder(&test_pass, layers);
}
/*
TEST_F(UTESTGraphPassesBasePass, suspend_pre_node) {
NamesToPass names_to_pass;
auto test_pass = UtestTestPass();
names_to_pass.push_back(std::make_pair("test", &test_pass));

// repass next_node immediately
test_pass.AddRePassNodeName("reshape1", "sum1");
// repass node after next_node immediately
test_pass.AddRePassNodeName("add1", "sum1");

auto graph = BuildGraph2();
auto ge_pass = GEPass(graph);
EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
EXPECT_EQ(test_pass.GetIterNodes().size(), 8);// todo
std::vector<std::unordered_set<std::string>> layers;
layers.push_back({"data1", "const1", "const2"});
layers.push_back({"shape1"});
layers.push_back({"add1", "addn1"});
layers.push_back({"reshape1", "sum1"});
CheckIterOrder(&test_pass, layers);
}*/
} // namespace ge

+ 45
- 0
tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc View File

@@ -293,6 +293,9 @@ class AddKernel : public Kernel {
} else if (input[0]->GetTensorDesc().GetDataType() == DT_INT32 || input[0]->GetTensorDesc().GetDataType() == DT_UINT32) {
vector<int32_t> data_vec;
auto data_num = input[0]->GetTensorDesc().GetShape().GetShapeSize();
if (input[0]->GetTensorDesc().GetShape().IsScalar()) {
data_num = 1;
}
auto x1_data = reinterpret_cast<const int32_t *>(input[0]->GetData().data());
auto x2_data = reinterpret_cast<const int32_t *>(input[1]->GetData().data());
for (size_t i = 0; i < data_num; i++) {
@@ -410,6 +413,48 @@ TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsHave
EXPECT_EQ(unknown_target_value_range, output_value_range);
}

TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsAreKnownValueRange_ScalarOutput) {
// shape --- add --- sqrt
// constant /
auto graph = std::make_shared<ComputeGraph>("test_graph");
vector<int32_t> data_vec = {2};
GeTensorDesc const_td(ge::GeShape(), ge::FORMAT_NCHW, ge::DT_INT32);
GeTensorPtr const_tensor = std::make_shared<ge::GeTensor>(const_td, (uint8_t *)data_vec.data(), sizeof(int32_t));
auto const_op_desc = std::make_shared<OpDesc>("Constant", "Constant");
const_op_desc->AddOutputDesc(const_td);
EXPECT_EQ(OpDescUtils::SetWeights(const_op_desc, const_tensor), GRAPH_SUCCESS);
auto const_node = graph->AddNode(const_op_desc);

GeTensorDesc shape_td(GeShape(), ge::FORMAT_NCHW, ge::DT_INT32);
std::vector<std::pair<int64_t, int64_t>> known_value_range = {make_pair(1, 100)};
shape_td.SetValueRange(known_value_range);
auto shape_op_desc = std::make_shared<OpDesc>("Shape", "Shape");
shape_op_desc->AddOutputDesc(shape_td);
auto shape_node = graph->AddNode(shape_op_desc);

GeTensorDesc add_td(GeShape(), ge::FORMAT_NCHW, ge::DT_INT32);
auto add_op_desc = std::make_shared<OpDesc>("Add", "Add");
add_op_desc->AddInputDesc(shape_td);
add_op_desc->AddInputDesc(const_td);
add_op_desc->AddOutputDesc(add_td);
auto add_node = graph->AddNode(add_op_desc);

ge::GraphUtils::AddEdge(shape_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(0));
ge::GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(1));

InferValueRangePass infer_pass;
EXPECT_EQ(infer_pass.Run(add_node), SUCCESS);

auto output_0_desc = add_node->GetOpDesc()->GetOutputDesc(0);
std::vector<std::pair<int64_t, int64_t>> out_value_range;
output_0_desc.GetValueRange(out_value_range);
EXPECT_EQ(out_value_range.size(), 1);

std::vector<int64_t> target_value_range = {3, 102};
std::vector<int64_t> output_value_range = {out_value_range[0].first, out_value_range[0].second};
EXPECT_EQ(output_value_range, target_value_range);
}

TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsAreKnownValueRange_Int64) {
// shape --- add --- sqrt
// constant /


+ 104
- 6
tests/ut/ge/graph/passes/infershape_pass_unittest.cc View File

@@ -29,13 +29,77 @@
using namespace std;
using namespace testing;
namespace ge {
namespace {
// do nothing stub infer_func
const auto stub_func = [](Operator &op) { return GRAPH_SUCCESS; };
// infer from input to output stub infer_func (input size == output size)
const auto stub_mapping_func = [](Operator &op) {
size_t in_num = op.GetInputsSize();
for (size_t i = 0; i < in_num; ++i) {
auto in_desc = op.GetInputDesc(i);
auto out_desc = op.GetOutputDesc(i);
out_desc.SetShape(in_desc.GetShape());
out_desc.SetDataType(in_desc.GetDataType());
op.UpdateOutputDesc(out_desc.GetName(), out_desc);
}
return GRAPH_SUCCESS;
};
// merge infer_func

// while infer_func
const auto while_infer_func = [](Operator &op) {
size_t in_num = op.GetInputsSize();
size_t out_num = op.GetOutputsSize();
if (in_num != out_num) {
return GRAPH_FAILED;
}
bool need_infer_again = false;
for (size_t i = 0; i < in_num; ++i) {
auto in_desc = op.GetDynamicInputDesc("input", i);
auto out_desc = op.GetDynamicOutputDesc("output", i);
auto data_shape = in_desc.GetShape();
auto out_shape = out_desc.GetShape();
if(out_shape.GetDims() == DUMMY_SHAPE){
return GRAPH_SUCCESS;
}
// check datatype between output and input
if (in_desc.GetDataType() != out_desc.GetDataType()) {
return GRAPH_FAILED;
}

if (data_shape.GetDims() != out_shape.GetDims()) {
need_infer_again = true;
if (data_shape.GetDimNum() != out_shape.GetDimNum()) {
in_desc.SetUnknownDimNumShape();
} else {
size_t data_dim_num = data_shape.GetDimNum();
std::vector<std::pair<int64_t, int64_t>> data_shape_range = {data_dim_num, std::make_pair(1, UNKNOWN_DIM)};
for (size_t j = 0; j < data_dim_num; ++j) {
if (data_shape.GetDim(j) != out_shape.GetDim(j)) {
data_shape.SetDim(j, UNKNOWN_DIM);
}
if (data_shape.GetDim(j) != UNKNOWN_DIM) {
data_shape_range[j] = std::make_pair(data_shape.GetDim(j), data_shape.GetDim(j));
}
}
in_desc.SetShape(data_shape);
in_desc.SetShapeRange(data_shape_range);
}
op.UpdateDynamicOutputDesc("output", i, in_desc);
op.UpdateDynamicInputDesc("input", i, in_desc);
}
}
return need_infer_again ? GRAPH_NODE_NEED_REPASS : GRAPH_SUCCESS;
};
}
class UtestGraphInfershapePass : public testing::Test {
protected:
void SetUp() {}
void TearDown() {}
};

static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) {
static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num,
std::function<graphStatus(Operator &)> infer_func = stub_func) {
OpDescPtr op_desc = std::make_shared<OpDesc>(name, type);
op_desc->SetStreamId(0);
static int32_t index = 0;
@@ -61,14 +125,11 @@ static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string
op_desc->SetWorkspaceBytes({});
op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE");

const auto stub_func = [](Operator &op) { return GRAPH_SUCCESS; };
op_desc->AddInferFunc(stub_func);
op_desc->AddInferFormatFunc(stub_func);
op_desc->AddVerifierFunc(stub_func);

op_desc->AddInferFunc(infer_func);
return graph.AddNode(op_desc);
}

/*
TEST_F(UtestGraphInfershapePass, infershape_pass_failed) {
GeTensorDesc ge_tensor_desc(GeShape({-2, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16);
string type = "AddN";
@@ -82,6 +143,7 @@ TEST_F(UtestGraphInfershapePass, infershape_pass_failed) {
InferShapePass infershape_pass;
EXPECT_EQ(infershape_pass.Run(addn_node), GE_GRAPH_INFERSHAPE_FAILED);
}
*/

TEST_F(UtestGraphInfershapePass, delete_need_infer_again) {
auto graph = std::make_shared<ComputeGraph>("test");
@@ -94,7 +156,43 @@ TEST_F(UtestGraphInfershapePass, delete_need_infer_again) {
infershape_pass.options_[kOptimizeAfterSubGraph] = "yes";
EXPECT_EQ(infershape_pass.Run(no_op_node), SUCCESS);
}
TEST_F(UtestGraphInfershapePass, infer_from_pre_to_next) {
/*
* cast->shape
*/
auto graph = std::make_shared<ComputeGraph>("test_infer_shape");
auto data1 = CreateNode(*graph, "dataq", DATA, 0, 1);
auto cast1 = CreateNode(*graph, "cast1", CAST, 1, 1, stub_mapping_func);
auto cast_in_desc = cast1->GetOpDesc()->MutableInputDesc(0);
cast_in_desc->SetShape(GeShape({1,2,3}));
cast_in_desc->SetDataType(DT_INT32);
auto transdata1 = CreateNode(*graph, "transdata1", TRANSDATA, 1, 1, stub_mapping_func);
GraphUtils::AddEdge(data1->GetOutDataAnchor(0), cast1->GetInDataAnchor(0));
GraphUtils::AddEdge(cast1->GetOutDataAnchor(0), transdata1->GetInDataAnchor(0));

// check before infer cast1
auto cast_before = graph->FindNode("cast1");
vector<int64_t> expect_cast1_shape_dim = {1,2,3};
auto real_cast1_before_shape_dim = cast_before->GetOpDesc()->GetInputDesc(0).GetShape().GetDims();
auto transdata1_before = graph->FindNode("transdata1");
vector<int64_t> expect_transdata1_shape_dim = {};
auto real_transdata1_before_shape_dim = transdata1_before->GetOpDesc()->GetInputDesc(0).GetShape().GetDims();
EXPECT_EQ(real_cast1_before_shape_dim, expect_cast1_shape_dim);
EXPECT_EQ(real_transdata1_before_shape_dim, expect_transdata1_shape_dim);
// run infershape pass
InferShapePass infer_shape_pass;
infer_shape_pass.Run(cast_before);
// check cast1 add transdata1 to repass_immediately
infer_shape_pass.GetNodesNeedRePassImmediately();
EXPECT_TRUE(!infer_shape_pass.GetNodesNeedRePassImmediately().empty());
// check transdata input_shape & datatype after infer
auto transdata1_after = graph->FindNode("transdata1");
auto transdata1_opdesc = transdata1_before->GetOpDesc();
auto real_transdata1_after_shape_dim = transdata1_opdesc->GetInputDesc(0).GetShape().GetDims();
EXPECT_EQ(real_transdata1_after_shape_dim, expect_cast1_shape_dim);
auto transdata1_datatype_after = transdata1_opdesc->GetInputDesc(0).GetDataType();
EXPECT_EQ(transdata1_datatype_after, DT_INT32);
}
TEST_F(UtestGraphInfershapePass, stop_node_for_while_loop) {
/*******************************************************************************
* Exit Identify


+ 28
- 0
tests/ut/ge/hybrid/executor/hybrid_model_async_executor_unittest.cc View File

@@ -103,4 +103,32 @@ TEST_F(UtestHybridModelAsyncExecutor, Test_execute) {
context.callback_manager->callback_queue_.Push(eof_entry);
ASSERT_EQ(executor.Execute(args), SUCCESS);
}

TEST_F(UtestHybridModelAsyncExecutor, test_PrepareInputs) {
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph);
ge_root_model->SetModelName("test_name");
GeModelPtr ge_sub_model = make_shared<GeModel>();
HybridModel hybrid_model(ge_root_model);
HybridModelAsyncExecutor executor(&hybrid_model);
GeTensorDescPtr tensor_desc = make_shared<GeTensorDesc>(GeShape({-1, 16, 16, 3}));
tensor_desc->SetShapeRange({{1, 256}, {16, 16}, {16, 16}, {3, 3}});
executor.input_tensor_desc_.insert({0, tensor_desc});
executor.device_id_ = 0;
executor.input_sizes_.insert({0, -1});
executor.is_input_dynamic_.push_back(true);

unique_ptr<uint8_t[]> data_buf(new (std::nothrow)uint8_t[3072]);
InputData input_data;
input_data.blobs.push_back(DataBuffer(data_buf.get(), 3072, false));
input_data.shapes.push_back({1, 16, 16, 3});
HybridModelExecutor::ExecuteArgs args;

auto ret = executor.PrepareInputs(input_data, args);
ASSERT_EQ(ret, SUCCESS);
ASSERT_EQ(args.input_desc[0]->GetShape().ToString(), GeShape({1, 16, 16, 3}).ToString());
int64_t tensor_size = 0;
TensorUtils::GetSize(*(args.input_desc[0]), tensor_size);
ASSERT_EQ(tensor_size, 3104);
}
} // namespace ge

+ 12
- 0
tests/ut/ge/hybrid/executor/worker/execution_engine_unittest.cc View File

@@ -27,6 +27,7 @@
#include "hybrid/executor/hybrid_model_executor.h"
#include "hybrid/executor/worker/execution_engine.h"
#include "hybrid/executor/subgraph_executor.h"
#include "hybrid/executor/worker/task_compile_engine.h"
#undef private
#undef protected

@@ -45,7 +46,14 @@ class UtestExecutionEngine : public testing::Test {
};
namespace {
const int kIntBase = 10;
class CompileNodeExecutor : public NodeExecutor {
public:
Status CompileTask(const HybridModel &model, const NodePtr &node, std::shared_ptr<NodeTask> &task) const override {
return SUCCESS;
}
};
}

static ge::OpDescPtr CreateOpDesc(string name = "", string type = "") {
auto op_desc = std::make_shared<ge::OpDesc>(name, type);
op_desc->SetStreamId(0);
@@ -128,4 +136,8 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) {
executor.InitCallback(node_state.get(), callback);
ExecutionEngine execution_engine;
EXPECT_EQ(execution_engine.ExecuteAsync(*node_state, node_state->GetTaskContext(), execution_context, callback), INTERNAL_ERROR);
CompileNodeExecutor node_executor;
node_item->node_executor = &node_executor;
EXPECT_EQ(TaskCompileEngine::Compile(*node_state, &execution_context), SUCCESS);
}

+ 1
- 0
tests/ut/ge/hybrid/ge_hybrid_unittest.cc View File

@@ -153,6 +153,7 @@ TEST_F(UtestGeHybrid, task_update_tiling_info) {
ge::AttrUtils::SetStr(op_desc, "compile_info_json", "json");
ge::AttrUtils::SetBool(op_desc, "support_dynamicshape", true);
ge::AttrUtils::SetInt(op_desc, "op_para_size", 1);
ge::AttrUtils::SetStr(op_desc, TVM_ATTR_NAME_MAGIC, "RT_DEV_BINARY_MAGIC_ELF");
auto node = graph->AddNode(op_desc);

std::unique_ptr<NodeItem> node_item;


+ 1
- 0
tests/ut/ge/hybrid/node_executor/node_executor_unittest.cc View File

@@ -87,6 +87,7 @@ TEST_F(NodeExecutorTest, TestGetOrCreateExecutor) {
TEST_F(NodeExecutorTest, TestInitAndFinalize) {
auto &manager = NodeExecutorManager::GetInstance();
manager.FinalizeExecutors();
manager.FinalizeExecutors();
manager.EnsureInitialized();
manager.EnsureInitialized();
const NodeExecutor *executor = nullptr;


+ 20
- 1
tests/ut/ge/single_op/single_op_model_unittest.cc View File

@@ -311,7 +311,7 @@ TEST_F(UtestSingleOpModel, BuildTaskList) {
ASSERT_EQ(mem_task.LaunchKernel(0), SUCCESS);
}

TEST_F(UtestSingleOpModel, build_aicpu_task) {
TEST_F(UtestSingleOpModel, build_dynamic_task) {
ComputeGraphPtr graph = make_shared<ComputeGraph>("single_op");
GeModelPtr ge_model = make_shared<GeModel>();
ge_model->SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph));
@@ -321,6 +321,15 @@ TEST_F(UtestSingleOpModel, build_aicpu_task) {
domi::TaskDef *task_def = model_task_def->add_task();
task_def->set_type(RT_MODEL_TASK_KERNEL_EX);

domi::TaskDef *task_def2 = model_task_def->add_task();
task_def2->set_type(RT_MODEL_TASK_KERNEL);
domi::KernelDef *kernel_def = task_def2->mutable_kernel();
domi::KernelContext *context = kernel_def->mutable_context();
context->set_kernel_type(6); // ccKernelType::AI_CPU

domi::TaskDef *task_def3 = model_task_def->add_task();
task_def3->set_type(RT_MODEL_TASK_ALL_KERNEL);

string model_data_str = "123456789";
SingleOpModel model("model", model_data_str.c_str(), model_data_str.size());
std::mutex stream_mu;
@@ -329,8 +338,18 @@ TEST_F(UtestSingleOpModel, build_aicpu_task) {
DynamicSingleOp single_op(0, &stream_mu, stream);
model.model_helper_.model_ = ge_model;
auto op_desc = std::make_shared<ge::OpDesc>("add", "Add");
AttrUtils::SetStr(op_desc, TVM_ATTR_NAME_MAGIC, "RT_DEV_BINARY_MAGIC_ELF");
std::vector<char> kernelBin;
TBEKernelPtr tbe_kernel = std::make_shared<ge::OpKernelBin>("name/Add", std::move(kernelBin));
op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel);
NodePtr node = graph->AddNode(op_desc);
model.op_list_[0] = node;
StreamResource *res = new (std::nothrow) StreamResource(1);

ASSERT_EQ(model.ParseTasks(), SUCCESS);
ASSERT_EQ(model.BuildTaskListForDynamicOp(res, single_op), SUCCESS);
model.tbe_tasks_.clear();
ASSERT_EQ(model.BuildTaskListForDynamicOp(res, single_op), SUCCESS);
model.aicpu_tasks_[0] = *task_def2;
model.BuildTaskListForDynamicOp(res, single_op);
}

+ 1
- 0
tests/ut/ge/single_op/single_op_task_unittest.cc View File

@@ -54,6 +54,7 @@ TEST_F(UtestSingleOpTask, test_build_kernel_task) {

auto graph = make_shared<ComputeGraph>("graph");
auto op_desc = make_shared<OpDesc>("Add", "Add");
AttrUtils::SetStr(op_desc, TVM_ATTR_NAME_MAGIC, "RT_DEV_BINARY_MAGIC_ELF");
std::vector<char> kernelBin;
TBEKernelPtr tbe_kernel = std::make_shared<ge::OpKernelBin>("name/Add", std::move(kernelBin));
op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel);


+ 7
- 0
third_party/fwkacllib/inc/external/runtime/rt_error_codes.h View File

@@ -38,6 +38,7 @@ static const int32_t ACL_ERROR_RT_STREAM_NO_CB_REG = 107015; // callba
static const int32_t ACL_ERROR_RT_INVALID_MEMORY_TYPE = 107016; // invalid memory type
static const int32_t ACL_ERROR_RT_INVALID_HANDLE = 107017; // invalid handle
static const int32_t ACL_ERROR_RT_INVALID_MALLOC_TYPE = 107018; // invalid malloc type
static const int32_t ACL_ERROR_RT_WAIT_TIMEOUT = 107019; // wait timeout
static const int32_t ACL_ERROR_RT_FEATURE_NOT_SUPPORT = 207000; // feature not support
static const int32_t ACL_ERROR_RT_MEMORY_ALLOCATION = 207001; // memory allocation error
@@ -50,6 +51,7 @@ static const int32_t ACL_ERROR_RT_NO_EVENT_RESOURCE = 207007; // no eve
static const int32_t ACL_ERROR_RT_NO_STREAM_RESOURCE = 207008; // no stream resource
static const int32_t ACL_ERROR_RT_NO_NOTIFY_RESOURCE = 207009; // no notify resource
static const int32_t ACL_ERROR_RT_NO_MODEL_RESOURCE = 207010; // no model resource
static const int32_t ACL_ERROR_RT_NO_CDQ_RESOURCE = 207011; // no cdq resource
static const int32_t ACL_ERROR_RT_INTERNAL_ERROR = 507000; // runtime internal error
static const int32_t ACL_ERROR_RT_TS_ERROR = 507001; // ts internel error
@@ -85,9 +87,14 @@ static const int32_t ACL_ERROR_RT_DEBUG_UNREGISTER_FAIL = 507030; // debug
static const int32_t ACL_ERROR_RT_LABEL_CONTEXT = 507031; // label not in current context
static const int32_t ACL_ERROR_RT_PROGRAM_USE_OUT = 507032; // program register num use out
static const int32_t ACL_ERROR_RT_DEV_SETUP_ERROR = 507033; // device setup error
static const int32_t ACL_ERROR_RT_VECTOR_CORE_TIMEOUT = 507034; // vector core timeout
static const int32_t ACL_ERROR_RT_VECTOR_CORE_EXCEPTION = 507035; // vector core exception
static const int32_t ACL_ERROR_RT_VECTOR_CORE_TRAP_EXCEPTION = 507036; // vector core trap exception
static const int32_t ACL_ERROR_RT_CDQ_BATCH_ABNORMAL = 507037; // cdq alloc batch abnormal
static const int32_t ACL_ERROR_RT_DRV_INTERNAL_ERROR = 507899; // drv internal error
static const int32_t ACL_ERROR_RT_AICPU_INTERNAL_ERROR = 507900; // aicpu internal error
static const int32_t ACL_ERROR_RT_SOCKET_CLOSE = 507901; // hdc disconnect
#ifdef __cplusplus
}


+ 4
- 4
third_party/fwkacllib/inc/runtime/base.h View File

@@ -156,7 +156,7 @@ RTS_API rtError_t rtProfilerTrace(uint64_t id, bool notify, uint32_t flags, rtSt

/**
* @ingroup profiling_base
* @brief ts send keypoint for step info.
* @brief ts send keypoint profiler log.
*/
RTS_API rtError_t rtProfilerTraceEx(uint64_t id, uint64_t modelId, uint16_t tagId, rtStream_t stream);

@@ -206,7 +206,7 @@ RTS_API rtError_t rtRegDeviceStateCallback(const char *regName, rtDeviceStateCal

/**
* @ingroup dvrt_base
* @brief register callback for fail task
* @brief register callback for fail task
* @param [in] uniName unique register name, can't be null
* @param [in] callback fail task callback function
* @param [out] NA
@@ -345,11 +345,11 @@ RTS_API rtError_t rtLabelCreateEx(rtLabel_t *label, rtStream_t stream);
* @return RT_ERROR_NONE for ok
* @return RT_ERROR_INVALID_VALUE for error input
*/
rtError_t rtLabelCreateExV2(rtLabel_t *label, rtModel_t model, rtStream_t stream);
RTS_API rtError_t rtLabelCreateExV2(rtLabel_t *label, rtModel_t model, rtStream_t stream);

/**
* @ingroup dvrt_base
* @brief get current thread last stream id and task id
* @brief get current thread last stream id and task id
* @param [out] stream id and task id
* @param [in] null
* @return RT_ERROR_NONE for ok


+ 43
- 0
third_party/fwkacllib/inc/runtime/config.h View File

@@ -46,6 +46,12 @@ typedef enum tagRtChipType {
CHIP_END,
} rtChipType_t;

typedef enum tagRtAicpuScheType {
SCHEDULE_SOFTWARE = 0, /* Software Schedule */
SCHEDULE_SOFTWARE_OPT,
SCHEDULE_HARDWARE, /* HWTS Schedule */
} rtAicpuScheType;

typedef enum tagRtVersion {
VER_BEGIN = 0,
VER_NA = VER_BEGIN,
@@ -65,6 +71,7 @@ typedef enum tagRtPlatformType {
PLATFORM_LHISI_CS,
PLATFORM_DC,
PLATFORM_CLOUD_V2,
PLATFORM_LHISI_SD3403,
PLATFORM_END,
} rtPlatformType_t;

@@ -126,6 +133,11 @@ typedef struct tagRtPlatformConfig {
uint32_t platformConfig;
} rtPlatformConfig_t;

typedef enum tagRTTaskTimeoutType {
RT_TIMEOUT_TYPE_OP_WAIT = 0,
RT_TIMEOUT_TYPE_OP_EXECUTE,
} rtTaskTimeoutType_t;

/**
* @ingroup
* @brief get AI core count
@@ -184,6 +196,37 @@ RTS_API rtError_t rtMemGetL2Info(rtStream_t stream, void **ptr, uint32_t *size);
*/
RTS_API rtError_t rtGetRuntimeVersion(uint32_t *runtimeVersion);


/**
* @ingroup
* @brief get device feature ability by device id, such as task schedule ability.
* @param [in] deviceId
* @param [in] moduleType
* @param [in] featureType
* @param [out] value
* @return RT_ERROR_NONE for ok
* @return RT_ERROR_INVALID_VALUE for error input
*/
RTS_API rtError_t rtGetDeviceCapability(int32_t deviceId, int32_t moduleType, int32_t featureType, int32_t *value);

/**
* @ingroup
* @brief set event wait task timeout time.
* @param [in] timeout
* @return RT_ERROR_NONE for ok
* @return RT_ERROR_INVALID_VALUE for error input
*/
RTS_API rtError_t rtSetOpWaitTimeOut(uint32_t timeout);

/**
* @ingroup
* @brief set op execute task timeout time.
* @param [in] timeout
* @return RT_ERROR_NONE for ok
* @return RT_ERROR_INVALID_VALUE for error input
*/
RTS_API rtError_t rtSetOpExecuteTimeOut(uint32_t timeout);

#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE)
}
#endif


+ 5
- 0
third_party/fwkacllib/inc/runtime/dev.h View File

@@ -63,6 +63,11 @@ typedef enum tagRtFeatureType {
FEATURE_TYPE_RSV
} rtFeatureType_t;

typedef enum tagRtDeviceFeatureType {
FEATURE_TYPE_SCHE,
FEATURE_TYPE_END,
} rtDeviceFeatureType_t;

typedef enum tagMemcpyInfo {
MEMCPY_INFO_SUPPORT_ZEROCOPY = 0,
MEMCPY_INFO_RSV


+ 35
- 2
third_party/fwkacllib/inc/runtime/event.h View File

@@ -23,12 +23,23 @@
extern "C" {
#endif

typedef enum rtEventWaitStatus {
EVENT_STATUS_COMPLETE = 0,
EVENT_STATUS_NOT_READY = 1,
EVENT_STATUS_MAX = 2,
} rtEventWaitStatus_t;

/**
* @ingroup event_flags
* @brief event op bit flags
*/
#define RT_EVENT_DEFAULT (0x00)
#define RT_EVENT_WITH_FLAG (0x01)
#define RT_EVENT_DEFAULT (0x0E)
#define RT_EVENT_WITH_FLAG (0x0B)

#define RT_EVENT_DDSYNC_NS 0x01U
#define RT_EVENT_STREAM_MARK 0x02U
#define RT_EVENT_DDSYNC 0x04U
#define RT_EVENT_TIME_LINE 0x08U

/**
* @ingroup dvrt_event
@@ -104,6 +115,16 @@ RTS_API rtError_t rtEventSynchronize(rtEvent_t event);
*/
RTS_API rtError_t rtEventQuery(rtEvent_t event);

/**
* @ingroup dvrt_event
* @brief Queries an event's wait status
* @param [in] event event to query
* @param [in out] EVENT_WAIT_STATUS status
* @return EVENT_STATUS_COMPLETE for complete
* @return EVENT_STATUS_NOT_READY for not complete
*/
RTS_API rtError_t rtEventQueryWaitStatus(rtEvent_t event, rtEventWaitStatus_t *status);

/**
* @ingroup dvrt_event
* @brief computes the elapsed time between events.
@@ -176,6 +197,18 @@ RTS_API rtError_t rtNotifyRecord(rtNotify_t notify, rtStream_t stream);
*/
RTS_API rtError_t rtNotifyWait(rtNotify_t notify, rtStream_t stream);

/**
* @ingroup dvrt_event
* @brief Wait for a notify with time out
* @param [in] notify_ notify to be wait
* @param [in] stream_ input stream
* @param [in] timeOut input timeOut
* @return RT_ERROR_NONE for ok
* @return RT_ERROR_INVALID_VALUE for error input
* @return RT_ERROR_STREAM_CONTEXT for stream is not in current ctx
*/
RTS_API rtError_t rtNotifyWaitWithTimeOut(rtNotify_t notify_, rtStream_t stream_, uint32_t timeOut);

/**
* @ingroup dvrt_event
* @brief Name a notify


+ 66
- 17
third_party/fwkacllib/inc/runtime/kernel.h View File

@@ -111,6 +111,16 @@ typedef struct rtKernelInfo {
uint32_t module_size;
} *rtKernelInfo_t;

/**
* @ingroup rt_kernel
* @brief op name
*/
typedef struct rtKernelLaunchNames {
const char *soName; // defined for so name
const char *kernelName; // defined for kernel type name
const char *opName; // defined for operator name
} rtKernelLaunchNames_t;

/**
* @ingroup rt_KernelConfigDump
* @brief device dump type
@@ -173,13 +183,7 @@ typedef void (*rtCallback_t)(void *fnData);
* @ingroup rt_kernel
* @brief magic number of elf binary for aicube
*/
#define RT_DEV_BINARY_MAGIC_ELF_AICUBE 0x41415247

/**
* @ingroup rt_kernel
* @brief magic number of elf binary for aivector
*/
#define RT_DEV_BINARY_MAGIC_ELF_AIVECTOR 0x41415248
#define RT_DEV_BINARY_MAGIC_ELF_AICUBE 0x41494343

/**
* @ingroup rt_kernel_flags
@@ -192,14 +196,14 @@ typedef void (*rtCallback_t)(void *fnData);
#define RT_KERNEL_CUSTOM_AICPU (0x08)

// STARS topic scheduler sqe : topic_type
#define RT_KERNEL_DEVICE_FIRST (0X10)
#define RT_KERNEL_HOST_ONLY (0X20)
#define RT_KERNEL_HOST_FIRST (0X30)
#define RT_KERNEL_DEVICE_FIRST (0x10)
#define RT_KERNEL_HOST_ONLY (0x20)
#define RT_KERNEL_HOST_FIRST (0x40)

/**
* @ingroup rt_kernel
* @brief kernel mode
*/
**/
#define RT_DEFAULT_KERNEL_MODE (0x00)
#define RT_NORMAL_KERNEL_MODE (0x01)
#define RT_ALL_KERNEL_MODE (0x02)
@@ -222,7 +226,7 @@ RTS_API rtError_t rtDevBinaryRegister(const rtDevBinary_t *bin, void **handle);

/**
* @ingroup rt_kernel
* @brief register device binary
* @brief register device binary with all kernel
* @param [in] bin device binary description
* @param [out] handle device binary handle
* @return RT_ERROR_NONE for ok
@@ -341,7 +345,7 @@ RTS_API rtError_t rtKernelLaunch(const void *stubFunc, uint32_t blockDim, void *
* @ingroup rt_kernel
* @brief launch kernel with handle to device
* @param [in] handle program
* @param [in] devFunc device function description
* @param [in] devFunc device function description.
* @param [in] blockDim block dimentions
* @param [in] args argments address for kernel function
* @param [in] argsSize argements size
@@ -352,7 +356,7 @@ RTS_API rtError_t rtKernelLaunch(const void *stubFunc, uint32_t blockDim, void *
* @return RT_ERROR_INVALID_VALUE for error input
*/
RTS_API rtError_t rtKernelLaunchWithHandle(void *handle, const void *devFunc, uint32_t blockDim, void *args, uint32_t argsSize,
rtSmDesc_t *smDesc, rtStream_t stream, const void *kernelInfo);
rtSmDesc_t *smDesc, rtStream_t stream_, const void *kernelInfo);

/**
* @ingroup rt_kernel
@@ -371,7 +375,7 @@ RTS_API rtError_t rtKernelLaunchWithFlag(const void *stubFunc, uint32_t blockDim
rtSmDesc_t *smDesc, rtStream_t stream, uint32_t flags);

/**
* @ingroup rt_kernel
* @ingroup rt_kernel(abandoned)
* @brief launch kernel to device
* @param [in] args argments address for kernel function
* @param [in] argsSize argements size
@@ -383,7 +387,21 @@ RTS_API rtError_t rtKernelLaunchWithFlag(const void *stubFunc, uint32_t blockDim
RTS_API rtError_t rtKernelLaunchEx(void *args, uint32_t argsSize, uint32_t flags, rtStream_t stream);

/**
* @ingroup rt_kernel
* @ingroup rt_kernel(in use)
* @brief launch kernel to device
* @param [in] opName opkernel name
* @param [in] args argments address for kernel function
* @param [in] argsSize argements size
* @param [in] flags launch flags
* @param [in] stream associated stream
* @return RT_ERROR_NONE for ok
* @return RT_ERROR_INVALID_VALUE for error input
*/
RTS_API rtError_t rtKernelLaunchFwk(const char *opName, void *args, uint32_t argsSize, uint32_t flags,
rtStream_t rtStream);

/**
* @ingroup rt_kernel(abandoned)
* @brief launch cpu kernel to device
* @param [in] soName so name
* @param [in] kernelName kernel name
@@ -399,7 +417,22 @@ RTS_API rtError_t rtCpuKernelLaunch(const void *soName, const void *kernelName,
uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream);

/**
* @ingroup rt_kernel
* @ingroup rt_kernel(in use)
* @brief launch cpu kernel to device
* @param [in] launchNames names for kernel launch
* @param [in] blockDim block dimentions
* @param [in] args argments address for kernel function
* @param [in] argsSize argments size
* @param [in] smDesc shared memory description
* @param [in] stream associated stream
* @return RT_ERROR_NONE for ok
* @return RT_ERROR_INVALID_VALUE for error input
*/
RTS_API rtError_t rtAicpuKernelLaunch(const rtKernelLaunchNames_t *launchNames,
uint32_t blockDim, const void *args, uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream);

/**
* @ingroup rt_kernel(abandoned)
* @brief launch cpu kernel to device with dump identifier
* @param [in] soName so name
* @param [in] kernelName kernel name
@@ -416,6 +449,22 @@ RTS_API rtError_t rtCpuKernelLaunchWithFlag(const void *soName, const void *kern
const void *args, uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream,
uint32_t flags);

/**
* @ingroup rt_kernel(in use)
* @brief launch cpu kernel to device with dump identifier
* @param [in] launchNames names for kernel launch
* @param [in] blockDim block dimentions
* @param [in] args argments address for kernel function
* @param [in] argsSize argments size
* @param [in] smDesc shared memory description
* @param [in] stream associated stream
* @param [in] flag dump flag or others function flag
* @return RT_ERROR_NONE for ok
* @return RT_ERROR_INVALID_VALUE for error input
*/
RTS_API rtError_t rtAicpuKernelLaunchWithFlag(const rtKernelLaunchNames_t *launchNames, uint32_t blockDim,
const void *args, uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream, uint32_t flags);

/**
* @ingroup rt_kernel
* @brief L1 fusion dump addr transfered to device


+ 11
- 0
third_party/fwkacllib/inc/runtime/mem.h View File

@@ -116,6 +116,9 @@ typedef enum tagRtMemInfoType {

typedef enum tagRtRecudeKind {
RT_MEMCPY_SDMA_AUTOMATIC_ADD = 10, // D2D, SDMA inline reduce, include 1P, and P2P
RT_MEMCPY_SDMA_AUTOMATIC_MAX = 11,
RT_MEMCPY_SDMA_AUTOMATIC_MIN = 12,
RT_MEMCPY_SDMA_AUTOMATIC_EQUAL = 13,
RT_RECUDE_KIND_END
} rtRecudeKind_t;

@@ -123,6 +126,14 @@ typedef enum tagRtDataType {
RT_DATA_TYPE_FP32 = 0, // fp32
RT_DATA_TYPE_FP16 = 1, // fp16
RT_DATA_TYPE_INT16 = 2, // int16
RT_DATA_TYPE_INT4 = 3, // int4
RT_DATA_TYPE_INT8 = 4, // int8
RT_DATA_TYPE_INT32 = 5, // int32
RT_DATA_TYPE_BFP16 = 6, // bfp16
RT_DATA_TYPE_BFP32 = 7, // bfp32
RT_DATA_TYPE_UINT8 = 8, // uint8
RT_DATA_TYPE_UINT16= 9, // uint16
RT_DATA_TYPE_UINT32= 10,// uint32
RT_DATA_TYPE_END
} rtDataType_t;



+ 11
- 2
third_party/fwkacllib/inc/runtime/rt_model.h View File

@@ -135,12 +135,13 @@ typedef struct tagAllKernelTaskInfo {
uint16_t argsCount;
uint16_t argsSize;
uint16_t reserved;
const void *dev_func;
void *devfunc;
void *handle;
uint8_t *smDesc;
uint8_t *args;
uint16_t *argsOffset;
} rtAllKernelTaskInfo_t;

typedef struct tagKernelTaskInfoEx {
uint32_t flags;
uint32_t argsSize;
@@ -198,6 +199,13 @@ typedef struct tagProfilerTraceTaskInfo {
uint32_t reserved[6];
} rtProfilerTrace_t;

typedef struct tagProfilerTraceExTaskInfo {
uint64_t profilerTraceId;
uint64_t modelId;
uint16_t tagId;
uint8_t reserved[22];
} rtProfilerTraceEx_t;

typedef struct tagrtMemcpyAsyncTaskInfo {
void *dst;
uint64_t destMax;
@@ -265,7 +273,7 @@ typedef struct tagTaskInfo {
union {
rtKernelTaskInfoEx_t kernelTaskEx;
rtKernelTaskInfo_t kernelTask;
rtAllKernelTaskInfo_t allkernelTask;
rtAllKernelTaskInfo_t allKernelTask;
rtEventTaskInfo_t eventTask;
rtStreamSwitchTaskInfo_t streamSwitchTask;
rtStreamActiveTaskInfo_t streamActiveTask;
@@ -273,6 +281,7 @@ typedef struct tagTaskInfo {
rtLabelSwitchTaskInfo_t labelSwitchTask;
rtLabelGotoTaskInfo_t labelGotoTask;
rtProfilerTrace_t profilertraceTask;
rtProfilerTraceEx_t profilertraceExTask;
rtMemcpyAsyncTaskInfo_t memcpyAsyncTask;
rtNotifyTaskInfo_t notifyTask;
rtReduceAsyncTaskInfo_t reduceAsyncTask;


+ 30
- 1
third_party/fwkacllib/inc/toolchain/prof_callback.h View File

@@ -108,7 +108,19 @@ enum MsprofCtrlCallbackType {
MSPROF_CTRL_INIT_ACL_ENV = 0, // start profiling with acl env
MSPROF_CTRL_INIT_ACL_JSON, // start profiling with acl.json
MSPROF_CTRL_INIT_GE_OPTIONS, // start profiling with ge env and options
MSPROF_CTRL_FINALIZE // stop profiling
MSPROF_CTRL_FINALIZE, // stop profiling
MSPROF_CTRL_REPORT_FUN_P, // for report callback
MSPROF_CTRL_PROF_SWITCH_ON, // for prof switch on
MSPROF_CTRL_PROF_SWITCH_OFF // for prof switch off
};

#define MSPROF_MAX_DEV_NUM (64)

struct MsprofCommandHandle {
uint64_t profSwitch;
uint32_t devNums; // length of device id list
uint32_t devIdList[MSPROF_MAX_DEV_NUM];
uint32_t modelId;
};

/**
@@ -129,6 +141,23 @@ typedef int32_t (*MsprofCtrlCallback)(uint32_t type, void *data, uint32_t len);
*/
typedef void (*MsprofSetDeviceCallback)(uint32_t devId, bool isOpenDevice);

/*
* @name MsprofInit
* @brief Profiling module init
* @param [in] dataType: profiling type: ACL Env/ACL Json/GE Option
* @param [in] data: profiling switch data
* @param [in] dataLen: Length of data
* @return 0:SUCCESS, >0:FAILED
*/
int32_t MsprofInit(uint32_t dataType, void *data, uint32_t dataLen);

/*
* @name AscendCL
* @brief Finishing Profiling
* @param NULL
* @return 0:SUCCESS, >0:FAILED
*/
int32_t MsprofFinalize();
#ifdef __cplusplus
}
#endif


Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save