diff --git a/.clang-format b/.clang-format index e7f9d935..6faea40d 100644 --- a/.clang-format +++ b/.clang-format @@ -52,7 +52,6 @@ ConstructorInitializerAllOnOneLineOrOnePerLine: true ConstructorInitializerIndentWidth: 4 ContinuationIndentWidth: 4 Cpp11BracedListStyle: true -DerivePointerAlignment: true DisableFormat: false ExperimentalAutoDetectBinPacking: false FixNamespaceComments: true @@ -94,7 +93,7 @@ PenaltyBreakString: 1000 PenaltyBreakTemplateDeclaration: 10 PenaltyExcessCharacter: 1000000 PenaltyReturnTypeOnItsOwnLine: 200 -PointerAlignment: Left +PointerAlignment: Right RawStringFormats: - Language: Cpp Delimiters: diff --git a/ge/common/CMakeLists.txt b/ge/common/CMakeLists.txt index 313f1ff3..1872b4c2 100755 --- a/ge/common/CMakeLists.txt +++ b/ge/common/CMakeLists.txt @@ -95,6 +95,7 @@ target_link_libraries(ge_common PRIVATE $<$>:$> $<$>:$> $<$>:$> + $<$>:$> static_mmpa -Wl,--no-as-needed graph @@ -155,6 +156,7 @@ target_link_libraries(ge_common_static PRIVATE $<$>:$> $<$>:$> $<$>:$> + $<$>:$> ascend_protobuf_static json c_sec diff --git a/ge/common/dump/dump_properties.cc b/ge/common/dump/dump_properties.cc index 010347c0..84bdb7bf 100644 --- a/ge/common/dump/dump_properties.cc +++ b/ge/common/dump/dump_properties.cc @@ -163,7 +163,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status DumpProperties::CheckDum GELOGE(PARAM_INVALID, "[Check][Param] the dumpPath:%s, is invalid.", input.c_str()); return PARAM_INVALID; } - if (mmAccess2(trusted_path, R_OK | W_OK) != EN_OK) { + if (mmAccess2(trusted_path, M_R_OK | M_W_OK) != EN_OK) { REPORT_INPUT_ERROR("E10001", std::vector({"parameter", "value", "reason"}), std::vector({ "ge.exec.dumpPath", diff --git a/ge/common/dump/exception_dumper.cc b/ge/common/dump/exception_dumper.cc index c8ec3d35..c41da551 100644 --- a/ge/common/dump/exception_dumper.cc +++ b/ge/common/dump/exception_dumper.cc @@ -161,6 +161,7 @@ Status ExceptionDumper::DumpExceptionInfo(const std::vector &ex uint64_t proto_size = dump_data.ByteSizeLong(); std::unique_ptr proto_msg(new (std::nothrow) char[proto_size]); + GE_CHECK_NOTNULL(proto_msg); bool ret = dump_data.SerializeToArray(proto_msg.get(), proto_size); if (!ret || proto_size == 0) { REPORT_INNER_ERROR("E19999", "Serialize proto to string fail"); diff --git a/ge/executor/CMakeLists.txt b/ge/executor/CMakeLists.txt index f258dffe..44ba3131 100755 --- a/ge/executor/CMakeLists.txt +++ b/ge/executor/CMakeLists.txt @@ -186,6 +186,8 @@ target_include_directories(ge_executor SYSTEM PRIVATE ${CMAKE_BINARY_DIR}/proto/graphengine_protos #### yellow zone #### $<$>:${GE_DEPEND_DIR}/inc> + $<$>:$> + $<$>:$> #### blue zone #### $<$:${GE_CODE_DIR}/third_party/fwkacllib/inc> $<$:${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain> @@ -251,6 +253,8 @@ target_link_libraries(ge_executor_shared PRIVATE $<$>:$> $<$>:$> $<$>:$> + $<$>:$> + $<$>:$> -Wl,--no-as-needed ge_common runtime diff --git a/ge/ge_runtime/task/label_goto_task.cc b/ge/ge_runtime/task/label_goto_task.cc index 4302bff3..7cb6d556 100644 --- a/ge/ge_runtime/task/label_goto_task.cc +++ b/ge/ge_runtime/task/label_goto_task.cc @@ -72,7 +72,7 @@ bool LabelGotoTask::Distribute() { return false; } - rt_ret = rtLabelListCpy((void**)label_list.data(), label_list.size(), label_info_, label_info_size); + rt_ret = rtLabelListCpy(reinterpret_cast(label_list.data()), label_list.size(), label_info_, label_info_size); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); return false; diff --git a/ge/graph/build/model_builder.cc b/ge/graph/build/model_builder.cc index d38e89fe..2816f170 100755 --- a/ge/graph/build/model_builder.cc +++ b/ge/graph/build/model_builder.cc @@ -32,7 +32,6 @@ #include "graph/ge_attr_value.h" #include "graph/ge_context.h" #include "external/graph/ge_error_codes.h" -#include "graph/manager/graph_mem_allocator.h" #include "graph/manager/graph_var_manager.h" #include "graph/optimize/common/params.h" #include "external/graph/types.h" @@ -707,7 +706,7 @@ Status ModelBuilder::SaveDataToModel(ge::Model &model, ge::GeModel &ge_model) { if (!kernel_name.empty() && (kernel_buffer.GetSize() > 0)) { GE_CHECK_NOTNULL(kernel_buffer.GetData()); std::vector data(kernel_buffer.GetData(), kernel_buffer.GetData() + kernel_buffer.GetSize()); - tbe_kernel = std::make_shared(kernel_name, std::move(data)); + tbe_kernel = MakeShared(kernel_name, std::move(data)); GE_CHECK_NOTNULL(tbe_kernel); GELOGI("Node [%s][%s] start recovery extra attr %s from %s", node_op_desc->GetName().c_str(), node_op_desc->GetType().c_str(), ge::OP_EXTATTR_NAME_TBE_KERNEL, ATTR_NAME_TBE_KERNEL_NAME.c_str()); diff --git a/ge/graph/build/task_generator.cc b/ge/graph/build/task_generator.cc index 5dee37d6..67289f73 100755 --- a/ge/graph/build/task_generator.cc +++ b/ge/graph/build/task_generator.cc @@ -793,7 +793,6 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP GELOGI("Start AutoFindBpOpIndex"); NodePtr bp_node = nullptr; uint32_t current_idx = 0; - uint32_t netoutput_idx = 0; for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { OpDescPtr op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); @@ -811,7 +810,6 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP if (op_desc->GetName() == NODE_NAME_NET_OUTPUT) { if (bp_node == nullptr) { bp_node = node; - netoutput_idx = current_idx - 1; } } if (graph->GetNeedIteration()) { @@ -836,34 +834,30 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP if (bp_node == nullptr) { GELOGW("not find bp_node."); return SUCCESS; - } else if (bp_node->GetName() == NODE_NAME_NET_OUTPUT) { - profiling_point.bp_index = netoutput_idx; - GELOGI("First bp name %s, idx %u", bp_node->GetName().c_str(), netoutput_idx); - } else { - profiling_point.bp_index = FindLastBpFromBpNode(graph, bp_node); } - return SUCCESS; + return FindLastBpFromBpNode(graph, bp_node, profiling_point.bp_index); } -uint32_t TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node) const { - uint32_t last_bp = 0; +Status TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &target_node, + uint32_t &bp_index) const { + bp_index = 0; + auto target_desc = target_node->GetOpDesc(); + GE_CHECK_NOTNULL(target_desc); OpDescPtr bp_op_desc = nullptr; - for (auto &in_anchor : bp_node->GetAllInDataAnchors()) { - auto out_anchor = in_anchor->GetPeerOutAnchor(); - if (out_anchor == nullptr || out_anchor->GetOwnerNode() == nullptr) { - continue; - } - auto out_node_desc = out_anchor->GetOwnerNode()->GetOpDesc(); - GE_CHECK_NOTNULL(out_node_desc); - if (bp_op_desc == nullptr || ((out_node_desc->GetId()) > (bp_op_desc->GetId()))) { - bp_op_desc = out_node_desc; + for (auto &in_node : target_node->GetInAllNodes()) { + GE_CHECK_NOTNULL(in_node); + auto in_node_desc = in_node->GetOpDesc(); + GE_CHECK_NOTNULL(in_node_desc); + if ((bp_op_desc == nullptr || (in_node_desc->GetId() > bp_op_desc->GetId())) && + (in_node_desc->GetStreamId() == target_desc->GetStreamId())){ + bp_op_desc = in_node_desc; } - GELOGI("bp_op_desc is %s, id is %ld", bp_op_desc->GetName().c_str(), bp_op_desc->GetId()); } if (bp_op_desc == nullptr) { - return last_bp; + GELOGI("Did not find bp node."); + return SUCCESS; } uint32_t current_idx = 0; for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { @@ -871,12 +865,14 @@ uint32_t TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const GE_CHECK_NOTNULL(op_desc); current_idx++; if (op_desc->GetName() == bp_op_desc->GetName()) { - last_bp = current_idx; - GELOGI("First bp name %s, idx %u", op_desc->GetName().c_str(), last_bp); + bp_index = current_idx; + GELOGI("Find bp name %s, idx %u", op_desc->GetName().c_str(), bp_index); break; } } - return last_bp; + GELOGI("Last bp node[%s], type[%s], index[%u], stream id[%ld]", bp_op_desc->GetName().c_str(), + bp_op_desc->GetType().c_str(), bp_index, bp_op_desc->GetStreamId()); + return SUCCESS; } Status TaskGenerator::FindFpOfEnv(const ComputeGraphPtr &graph, const std::string &fp_point_str, diff --git a/ge/graph/build/task_generator.h b/ge/graph/build/task_generator.h index 6f460906..5d204c3c 100755 --- a/ge/graph/build/task_generator.h +++ b/ge/graph/build/task_generator.h @@ -116,7 +116,7 @@ class TaskGenerator { Status AutoFindFpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point) const; Status AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point, vector &all_reduce_nodes) const; - uint32_t FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node) const; + Status FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node, uint32_t &bp_index) const; Status FindFpOfEnv(const ComputeGraphPtr &graph, const std::string &fp_point_str, ProfilingPoint &profiling_point) const; diff --git a/ge/graph/load/model_manager/task_info/kernel_task_info.cc b/ge/graph/load/model_manager/task_info/kernel_task_info.cc index bfb6e24b..07ad63ca 100755 --- a/ge/graph/load/model_manager/task_info/kernel_task_info.cc +++ b/ge/graph/load/model_manager/task_info/kernel_task_info.cc @@ -645,6 +645,7 @@ Status KernelTaskInfo::InitTVMTask(uint16_t offset, const domi::KernelDef &kerne GE_CHECK_NOTNULL(op_desc); args_addr = std::unique_ptr(new (std::nothrow) uint8_t[args_size_]); + GE_CHECK_NOTNULL(args_addr); errno_t sec_ret = memcpy_s(args_addr.get(), args_size_, kernel_def.args().data(), args_size_); if (sec_ret != EOK) { REPORT_CALL_ERROR("E19999", "Call memcpy_s fail, size:%u, ret:0x%X", args_size_, sec_ret); @@ -1000,6 +1001,7 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k // copy args to new host memory args_addr = std::unique_ptr(new (std::nothrow) uint8_t[args_size_]); + GE_CHECK_NOTNULL(args_addr); GE_PRINT_DYNAMIC_MEMORY(new, "cce task physical memory.", sizeof(uint8_t) * args_size_) errno_t sec_ret = memcpy_s(args_addr.get(), args_size_, kernel_def.args().data(), args_size_); if (sec_ret != EOK) { diff --git a/ge/graph/manager/graph_var_manager.cc b/ge/graph/manager/graph_var_manager.cc index ced8465f..89a4e45b 100755 --- a/ge/graph/manager/graph_var_manager.cc +++ b/ge/graph/manager/graph_var_manager.cc @@ -194,35 +194,6 @@ ge::Status VarResource::GetBroadCastInfo(uint32_t graph_id, const string &var_na return SUCCESS; } -ge::Status VarResource::SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name, - const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) { - GE_CHECK_NOTNULL(base_ptr); - GELOGI("SyncVarData2BroadCast graph_id: %u, var_name: %s.", graph_id, var_name.c_str()); - - VarBroadCastInfo var_broadcast_info = var_broad_cast_info_[graph_id][var_name]; - uint8_t *dst_addr = base_ptr + var_broadcast_info.input_offset; - - return ge::TransVarDataUtils::SyncVarData2BroadCast(var_name, var_tensor_desc, dst_addr, - var_broadcast_info.input_size, session_id_); -} - -ge::Status VarResource::SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name, - const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) { - GELOGI("SyncBroadCastData2Var var_name: %s", var_name.c_str()); - - VarBroadCastInfo var_broadcast_info = var_broad_cast_info_[graph_id][var_name]; - // subgraph base_ptr could be nullptr, task it as base 0 - uint8_t *dst_addr = base_ptr + var_broadcast_info.output_offset; - - return ge::TransVarDataUtils::SyncBroadCastData2Var(dst_addr, var_broadcast_info.output_size, var_name, - var_tensor_desc, session_id_); -} - -ge::Status VarResource::SyncVarData(uint32_t graph_id, const std::string &var_name, - const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) { - return SyncVarData2BroadCast(graph_id, var_name, var_tensor_desc, base_ptr); -} - bool VarResource::IsVarAddr(const int64_t &offset) { return var_offset_map_.count(offset) > 0; } rtMemType_t VarResource::GetVarMemType(const int64_t &offset) { @@ -638,16 +609,6 @@ bool VarManager::IsVarExist(const std::string &var_name) { return var_resource_->IsVarExist(var_name); } -ge::Status VarManager::SyncVarData(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc, - uint8_t *base_ptr) { - std::lock_guard lock(mutex_); - if (var_resource_ == nullptr) { - GELOGW("VarManager has not been init."); - return ge::INTERNAL_ERROR; - } - return var_resource_->SyncVarData(graph_id, var_name, var_tensor_desc, base_ptr); -} - ge::Status VarManager::GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc) { std::lock_guard lock(mutex_); GELOGI("VarManager::GetCurVarDesc var_name = %s.", var_name.c_str()); @@ -701,16 +662,6 @@ ge::Status VarManager::RenewCurVarDesc(const std::string &var_name, ge::OpDescPt return var_resource_->RenewCurVarDesc(var_name, std::move(op_desc)); } -ge::Status VarManager::SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name, - const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) { - std::lock_guard lock(mutex_); - if (var_resource_ == nullptr) { - GELOGW("VarManager has not been init."); - return ge::INTERNAL_ERROR; - } - return var_resource_->SyncBroadCastData2Var(graph_id, var_name, var_tensor_desc, base_ptr); -} - bool VarManager::IsVarAddr(const int64_t &offset) { std::lock_guard lock(mutex_); if (var_resource_ == nullptr) { diff --git a/ge/graph/manager/graph_var_manager.h b/ge/graph/manager/graph_var_manager.h index 736466c4..f2b68e79 100755 --- a/ge/graph/manager/graph_var_manager.h +++ b/ge/graph/manager/graph_var_manager.h @@ -118,15 +118,6 @@ class VarResource { ge::Status GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info); - ge::Status SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name, - const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr); - - ge::Status SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name, - const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr); - - ge::Status SyncVarData(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc, - uint8_t *base_ptr); - Status SetTransRoad(const std::string &var_name, const VarTransRoad &trans_road) { if (var_to_trans_road_.find(var_name) != var_to_trans_road_.end()) { GELOGW("Var name: %s has already set.", var_name.c_str()); @@ -234,16 +225,10 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr); - ge::Status SyncVarData(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc, - uint8_t *base_ptr); - ge::Status SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info); ge::Status GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info); - ge::Status SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc, - uint8_t *base_ptr); - ge::Status GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc); ge::Status RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc); diff --git a/ge/graph/manager/trans_var_data_utils.cc b/ge/graph/manager/trans_var_data_utils.cc index 4c25dff1..2e6ce454 100644 --- a/ge/graph/manager/trans_var_data_utils.cc +++ b/ge/graph/manager/trans_var_data_utils.cc @@ -415,72 +415,6 @@ Status CopyTensorFromSrcVarNode(const NodePtr &var_src, return SUCCESS; } } // namespace -Status TransVarDataUtils::SyncVarData2BroadCast(const string &var_name, const ge::GeTensorDesc &src_tensor_desc, - uint8_t *dst_addr, int64_t dst_addr_size, uint64_t session_id) { - GE_CHK_BOOL_RET_STATUS(dst_addr != nullptr, FAILED, "[Check][Param] dst addr is nullptr."); - uint8_t *src_host_addr = nullptr; - int64_t src_addr_size = 0; - GE_MAKE_GUARD_RTMEM(src_host_addr); - GE_CHK_STATUS_RET(SyncTensorToHost(var_name, src_tensor_desc, &src_host_addr, src_addr_size, session_id)); - - GELOGI("src_addr_size: %ld, dst_addr_size: %ld", src_addr_size, dst_addr_size); - GE_CHK_BOOL_RET_STATUS(src_addr_size == dst_addr_size, FAILED, - "[Check][Param] src_addr_size:%ld not equal to dst_addr_size:%ld", - src_addr_size, dst_addr_size); - - GE_CHK_RT_RET(rtMemcpy(dst_addr, dst_addr_size, src_host_addr, src_addr_size, RT_MEMCPY_HOST_TO_DEVICE)); - return SUCCESS; -} - -Status TransVarDataUtils::SyncBroadCastData2Var(uint8_t *src_addr, int64_t src_addr_size, const string &var_name, - const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id) { - GE_CHK_BOOL_RET_STATUS(src_addr != nullptr, FAILED, "[Check][Param] src addr is nullptr. "); - uint8_t *host_addr = nullptr; - GE_MAKE_GUARD_RTMEM(host_addr); - GE_CHK_RT_RET(rtMallocHost(reinterpret_cast(&host_addr), src_addr_size)); - GE_CHK_RT_RET(rtMemcpy(host_addr, src_addr_size, src_addr, src_addr_size, RT_MEMCPY_DEVICE_TO_HOST)); - - GE_CHK_STATUS_RET( - SyncTensorToDevice(var_name, reinterpret_cast(host_addr), src_addr_size, dst_tensor_desc, session_id)); - - return SUCCESS; -} - -Status TransVarDataUtils::SyncTensorToHost(const string &var_name, const ge::GeTensorDesc &src_tensor_desc, - uint8_t **host_addr, int64_t &src_tensor_size, uint64_t session_id) { - GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(src_tensor_desc, src_tensor_size), "[Get][Size] from TensorDesc failed"); - - uint8_t *src_addr = nullptr; - GE_CHK_STATUS_RET(VarManager::Instance(session_id)->GetVarAddr(var_name, src_tensor_desc, &src_addr)); - uint8_t *mem_addr = - src_addr - - static_cast(static_cast(VarManager::Instance(session_id)->GetVarMemLogicBase())) + - static_cast( - reinterpret_cast(VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM))); - GE_CHK_RT_RET(rtMallocHost(reinterpret_cast(host_addr), src_tensor_size)); - - GE_CHK_RT_RET(rtMemcpy(*host_addr, src_tensor_size, mem_addr, src_tensor_size, RT_MEMCPY_DEVICE_TO_HOST)); - - GELOGI("SyncTensorToHost var_name %s, src_tensor_size %ld", var_name.c_str(), src_tensor_size); - return SUCCESS; -} - -Status TransVarDataUtils::SyncTensorToDevice(const string &var_name, const uint8_t *host_addr, uint32_t addr_size, - const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id) { - uint8_t *dst_addr = nullptr; - GE_CHK_STATUS_RET(VarManager::Instance(session_id)->GetVarAddr(var_name, dst_tensor_desc, &dst_addr)); - uint8_t *mem_addr = - dst_addr - - static_cast(static_cast(VarManager::Instance(session_id)->GetVarMemLogicBase())) + - static_cast( - reinterpret_cast(VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM))); - GE_CHK_RT_RET(rtMemcpy(mem_addr, addr_size, host_addr, addr_size, RT_MEMCPY_HOST_TO_DEVICE)); - - GELOGI("SyncTensorToDevice var_name %s, addr_size %u", var_name.c_str(), addr_size); - - return SUCCESS; -} - Status TransVarDataUtils::TransAllVarData(const vector &variable_nodes, uint64_t session_id, rtContext_t context, diff --git a/ge/graph/manager/trans_var_data_utils.h b/ge/graph/manager/trans_var_data_utils.h index d5096ef2..174efbb3 100755 --- a/ge/graph/manager/trans_var_data_utils.h +++ b/ge/graph/manager/trans_var_data_utils.h @@ -29,11 +29,6 @@ namespace ge { class TransVarDataUtils { public: - static ge::Status SyncVarData2BroadCast(const string &var_name, const ge::GeTensorDesc &src_tensor_desc, - uint8_t *dst_addr, int64_t dst_addr_size, uint64_t session_id_); - static ge::Status SyncBroadCastData2Var(uint8_t *src_addr, int64_t src_addr_size, const string &var_name, - const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id_); - static ge::Status TransAllVarData(const std::vector &variable_nodes, uint64_t session_id, rtContext_t context, @@ -41,12 +36,6 @@ class TransVarDataUtils { uint32_t thread_num = 16); static ge::Status CopyVarData(const ComputeGraphPtr &compute_graph, uint64_t session_id, uint32_t device_id); - - private: - static ge::Status SyncTensorToHost(const string &var_name, const ge::GeTensorDesc &src_tensor_desc, - uint8_t **host_addr, int64_t &addr_size, uint64_t session_id_); - static ge::Status SyncTensorToDevice(const string &var_name, const uint8_t *host_addr, uint32_t addr_size, - const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id_); }; } // namespace ge diff --git a/ge/graph/passes/base_pass.cc b/ge/graph/passes/base_pass.cc index b74eb5cd..19682d23 100755 --- a/ge/graph/passes/base_pass.cc +++ b/ge/graph/passes/base_pass.cc @@ -46,50 +46,116 @@ void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, bool AllNodesIn(const Node::Vistor &nodes, const std::unordered_set &nodes_set) { return !std::any_of(nodes.begin(), nodes.end(), [&](const NodePtr &n) { + return nodes_set.count(n) == 0; + }); +} + +bool AnyNodesIn(const Node::Vistor &nodes, const std::unordered_set &nodes_set) { + return std::any_of(nodes.begin(), nodes.end(), [&](const NodePtr &n) { return nodes_set.count(n) > 0; }); } -void AddNextIterNodes(const NodePtr &cur_node, GEPass::GraphLevelState &g_state) { - const auto &nodes_suspend = g_state.nodes_suspend; +bool IsNodeReadyToQueue(const NodePtr &node, GEPass::GraphLevelState &g_state) { + if (node == nullptr) { + GELOGW("node is null"); + return false; + } + + if (g_state.nodes_deleted.count(node) > 0) { + GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str()); + return false; + } + + if (g_state.nodes_last.count(node) != 0) { + return false; + } + + if (!node->IsAllInNodesSeen(g_state.nodes_seen)) { + return false; + } + + // 因为在PassNode之前,会首先将node的输出节点添加queue,因此若在pass node时,suspend了node的输出节点,后续逻辑与上面相同 + // TODO 需要注意的是,这里的保证是一次”尽力而为“,若pass node时,将node之前的节点`A`添加到了suspend, + // 那么`A`节点的后继和间接后继节点的pass不会受到suspend的影响 + // 理论上来说,如果在pass node之前,首先收集node的输出节点,在pass后,将输出节点做suspend、delete的去除,然后加queue, + // 这样处理就不需要在这里做额外的确认了 + if (g_state.nodes_suspend.count(node) > 0) { + GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.", + node->GetName().c_str()); + return false; + } + if (AnyNodesIn(node->GetInAllNodes(), g_state.nodes_suspend)) { + GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.", + node->GetName().c_str()); + return false; + } + return true; +} + +void CollectOutNodesBeforePass(const NodePtr &node, std::unordered_set &out_nodes_before_pass) { + for (const auto &out_node : node->GetOutNodes()) { + out_nodes_before_pass.insert(out_node); + } +} + +void AddNextIterNodes(const NodePtr &cur_node, std::unordered_set &out_nodes_before_pass, + GEPass::GraphLevelState &g_state) { for (auto &node : cur_node->GetOutNodes()) { if (node == nullptr) { continue; } - if (g_state.nodes_last.count(node) != 0) { - continue; + if (out_nodes_before_pass.erase(node) == 0) { + // after pass node , new output node come up + GELOGD("New output nodes %s come up after pass %s.", node->GetName().c_str(), cur_node->GetName().c_str()); + } + + if (IsNodeReadyToQueue(node, g_state)) { + g_state.AddNodeToQueueIfNotSeen(node); } - if (nodes_suspend.count(node) > 0) { - GELOGD("The node %s has suspend by pass, skip it.", node->GetName().c_str()); + } + // A-->B-->C + // \ + // D--->E + // If B has been delete after pass, two case need to consider + // 1. A & C & E has been repass by B. good choice + // 2. A & C & E not added to repass, C will not pass because no one trigger it. + // while E will pass because D will trigger it. + // So here we need add node which has no input_node to queue. + for (const auto &node : out_nodes_before_pass) { + if (!node->GetInAllNodes().empty()) { + GELOGD("Node %s used to be output of node %s, but after pass it doesnt. " + "It may triggered by other node, so no need add to queue now."); continue; } - - if (node->IsAllInNodesSeen(g_state.nodes_seen) && AllNodesIn(node->GetInAllNodes(), nodes_suspend)) { + if (IsNodeReadyToQueue(node, g_state)) { + // unlink edge may happen, add these node to queue otherwise they can not pass + GELOGI("Node %s may lost from cur node, add to queue if not seen.", + node->GetName().c_str(), cur_node->GetName().c_str()); g_state.AddNodeToQueueIfNotSeen(node); } } } -void AddImmediateRepassNodesToQueue(NodePtr &cur_node, const std::pair &name_to_pass, - const std::unordered_set &nodes_im_re_pass, +void AddImmediateRepassNodesToQueue(NodePtr &cur_node, + const std::unordered_map re_pass_imm_nodes_to_pass_names, GEPass::GraphLevelState &g_state) { - for (const auto &node : nodes_im_re_pass) { - if (node == nullptr) { - GELOGW("Found null immediately re-pass node when executing pass %s on node %s type %s", name_to_pass.first.c_str(), + for (const auto &node_2_pass_names : re_pass_imm_nodes_to_pass_names) { + auto repass_imm_node = node_2_pass_names.first; + if (repass_imm_node == nullptr) { + GELOGW("Found null immediately re-pass node when executing pass %s on node %s type %s", + node_2_pass_names.second.c_str(), cur_node->GetName().c_str(), cur_node->GetType().c_str()); continue; } - if (g_state.nodes_passed.count(node) > 0) { - g_state.AddNodeToQueueFront(node); - continue; - } - // exp: constant folding add new const need repass immediate - if (AllNodesIn(node->GetInAllNodes(), g_state.nodes_passed)) { - g_state.AddNodeToQueueFront(node); + if (g_state.nodes_passed.count(repass_imm_node) > 0) { + GELOGD("The node %s specified by pass %s has been passed, it will repass immediately", + repass_imm_node->GetName().c_str(), node_2_pass_names.second.c_str()); + g_state.AddNodeToQueueFront(repass_imm_node); continue; } GELOGW("The node %s specified by pass %s has un-passed in_nodes, it will not repass immediately", - node->GetName().c_str(), name_to_pass.first.c_str()); + repass_imm_node->GetName().c_str(), node_2_pass_names.second.c_str()); } } @@ -103,23 +169,18 @@ void AddLastNodesToQueue(GEPass::GraphLevelState &g_state) { g_state.nodes_last.clear(); } -void SuspendAndResume(const std::string &pass_name, - const std::unordered_set &nodes_suspend, - const std::unordered_set &nodes_resume, - GEPass::GraphLevelState &g_state) { - // TODO 当前没有记录NodePass中suspend和resume的顺序,因此无法辨别NodePass中是先做Suspend还是Resume。 - // 因此此处的简单处理是如果在NodePass的过程中,触发了suspend/resume,那么框架以resume为准 - // 更好的处理方式是,在NodePass做suspend/resume时,做顺序的记录,在此函数中按序做回放 - for (const auto &node : nodes_suspend) { - GELOGD("The iteration suspend of node %s has been set by pass %s", node->GetName().c_str(), pass_name.c_str()); - g_state.nodes_suspend.insert(node); - } - - for (const auto &node : nodes_resume) { +void AddResumeNodesToQueue(const std::unordered_map resume_nodes_to_pass_names, + GEPass::GraphLevelState &g_state) { + // Currently we dont keep the order of suspend nodes and resume nodes, so its hard to know + // which one comes first. Simple way : if a node both have suspend & resume state, we will resume it. + // Better way: keep the order when suspend/resume a node, and in this func suspend/resume in order. + for (const auto &node_2_pass_names : resume_nodes_to_pass_names) { + auto node = node_2_pass_names.first; if (g_state.nodes_suspend.erase(node) > 0) { if (g_state.nodes_seen.count(node.get()) > 0 || node->IsAllInNodesSeen(g_state.nodes_seen)) { g_state.nodes.push_back(node); - GELOGD("Node %s has been resumed by pass %s", node->GetName().c_str(), pass_name.c_str()); + GELOGD("Node %s has been resumed by pass %s, add to queue.", + node->GetName().c_str(), node_2_pass_names.second.c_str()); } } } @@ -154,36 +215,6 @@ void ClearOption(NamesToPass names_to_pass) { name_to_pass.second->ClearOptions(); } } - -bool ShouldNodePassActually(const NodePtr &node, const GEPass::GraphLevelState &g_state) { - if (node == nullptr) { - GELOGW("node is null"); - return false; - } - // 因为在PassNode之前,会首先将node的输出节点添加queue,因此若在pass node时,删除了node的输出节点, - // 那么会出现:已经删除的节点出现在queue中,并且被pop出来,因此这里做确认,如果node已经被删除过了,就跳过pass - if (g_state.nodes_deleted.count(node) > 0) { - GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str()); - return false; - } - - // 因为在PassNode之前,会首先将node的输出节点添加queue,因此若在pass node时,suspend了node的输出节点,后续逻辑与上面相同 - // TODO 需要注意的是,这里的保证是一次”尽力而为“,若pass node时,将node之前的节点`A`添加到了suspend, - // 那么`A`节点的后继和间接后继节点的pass不会受到suspend的影响 - // 理论上来说,如果在pass node之前,首先收集node的输出节点,在pass后,将输出节点做suspend、delete的去除,然后加queue, - // 这样处理就不需要在这里做额外的确认了 - if (g_state.nodes_suspend.count(node) > 0) { - GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.", - node->GetName().c_str()); - return false; - } - if (!AllNodesIn(node->GetInAllNodes(), g_state.nodes_suspend)) { - GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.", - node->GetName().c_str()); - return false; - } - return true; -} } // namespace Status BaseNodePass::IsolateAndDeleteNode(NodePtr &node, const std::vector &io_map, @@ -256,18 +287,20 @@ void NotifyPassGraphStart(const ComputeGraphPtr &graph, const NamesToPass &names } Status GEPass::HandleLeakedSuspendNodes(const NamesToPass &names_to_passes, GraphLevelState &g_state) { + std::unordered_map resume_nodes_to_pass_names; for (auto &name_to_pass : names_to_passes) { name_to_pass.second->init(); auto ret = name_to_pass.second->OnSuspendNodesLeaked(); if (ret != SUCCESS) { - // todo error + GELOGE(ret, "Internal Error happened when pass %s handle on suspend nodes leaked.", + name_to_pass.first.c_str()); return ret; } - SuspendAndResume(name_to_pass.first, - name_to_pass.second->GetNodesSuspend(), - name_to_pass.second->GetNodesResume(), - g_state); + for (const auto &resume_node : name_to_pass.second->GetNodesResume()){ + resume_nodes_to_pass_names[resume_node].append(name_to_pass.first + ","); + } } + AddResumeNodesToQueue(resume_nodes_to_pass_names, g_state); return SUCCESS; } @@ -283,11 +316,13 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { if (!g_state.nodes_suspend.empty()) { auto ret = HandleLeakedSuspendNodes(names_to_passes, g_state); if (ret != SUCCESS) { - // todo log + GELOGE(ret, "Failed to handle leaked suspend nodes, break base pass."); return ret; } if (g_state.nodes.empty()) { - // todo 报错,因为suspend泄露场景,没有子类做进一步的resume,此处可能已经彻底泄露,需要报错 + // There are suspend nodes leaked, but no pass resume it + GELOGE(INTERNAL_ERROR, "There are suspend nodes but no pass resume, which means" + "some nodes in this graph never pass."); return INTERNAL_ERROR; } } @@ -305,6 +340,7 @@ Status GEPass::RunPassesGraphRepass(const NamesToPass &names_to_passes, GraphLev RepassLevelState rp_state; do { for (auto &node : rp_state.nodes_re_pass) { + GELOGD("Add node %s to queue for re-pass.", node->GetName().c_str()); g_state.AddNodeToQueue(node); } rp_state.nodes_re_pass.clear(); @@ -312,12 +348,14 @@ Status GEPass::RunPassesGraphRepass(const NamesToPass &names_to_passes, GraphLev while (!g_state.nodes.empty()) { auto node = g_state.PopFront(); - (void)rp_state.nodes_re_pass.erase(node); // todo 回忆一下为什么 - if (!ShouldNodePassActually(node, g_state)) { - continue; + if (g_state.nodes_deleted.count(node) > 0) { + GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str()); } + (void)rp_state.nodes_re_pass.erase(node);// todo why g_state.nodes_seen.insert(node.get()); // todo 为什么这里seen - AddNextIterNodes(node, g_state); + + std::unordered_set out_nodes_before_pass; + CollectOutNodesBeforePass(node, out_nodes_before_pass); auto ret = RunPassesNodeOnce(node, names_to_passes, g_state, rp_state); if (ret != SUCCESS) { @@ -325,6 +363,7 @@ Status GEPass::RunPassesGraphRepass(const NamesToPass &names_to_passes, GraphLev node->GetType().c_str(), ret); return ret; } + AddNextIterNodes(node, out_nodes_before_pass, g_state); } AddLastNodesToQueue(g_state); } while ((!rp_state.nodes_re_pass.empty() || !g_state.nodes.empty()) && ++g_state.re_pass_times < kMaxRePassTimes); @@ -405,10 +444,11 @@ Status GEPass::RunPassesOnNode(NodePtr &node, const NamesToPass &names_to_passes name_to_pass.second->init(); auto result = name_to_pass.second->Run(node); if (result != SUCCESS) { - REPORT_CALL_ERROR("E19999", "process pass %s on node:%s failed, ret:%u", - name_to_pass.first.c_str(), node->GetName().c_str(), result); - GELOGE(INTERNAL_ERROR, "[Process][Pass] %s on node %s failed, result " - "%u, the passes will be terminated immediately.", + REPORT_CALL_ERROR("E19999", "process pass %s on node:%s failed, ret:%u", name_to_pass.first.c_str(), + node->GetName().c_str(), result); + GELOGE(INTERNAL_ERROR, + "[Process][Pass] %s on node %s failed, result " + "%u, the passes will be terminated immediately.", name_to_pass.first.c_str(), node->GetName().c_str(), result); return result; } @@ -421,23 +461,30 @@ Status GEPass::RunPassesOnNode(NodePtr &node, const NamesToPass &names_to_passes g_state.nodes_passed.insert(node); + std::unordered_map repass_imm_nodes_to_pass_names; + std::unordered_map resume_nodes_to_pass_names; + // if multi pass add one node to repass immediately, here need to remove duplication for (const auto &name_to_pass : names_to_passes) { - PushToRePassIfSeen(node, name_to_pass, g_state.nodes_seen, - name_to_pass.second->GetNodesNeedRePass(), + PushToRePassIfSeen(node, name_to_pass, g_state.nodes_seen, name_to_pass.second->GetNodesNeedRePass(), rp_state.nodes_re_pass); + // collect imm_node && resume_node among these passes + for (const auto &imm_node : name_to_pass.second->GetNodesNeedRePassImmediately()) { + repass_imm_nodes_to_pass_names[imm_node].append(name_to_pass.first + ","); + } + for (const auto &resume_node : name_to_pass.second->GetNodesResume()) { + resume_nodes_to_pass_names[resume_node].append(name_to_pass.first + ","); + } - AddImmediateRepassNodesToQueue(node, name_to_pass, - name_to_pass.second->GetNodesNeedRePassImmediately(), - g_state); - SuspendAndResume(name_to_pass.first, - name_to_pass.second->GetNodesSuspend(), - name_to_pass.second->GetNodesResume(), - g_state); - + for (const auto &suspend_node : name_to_pass.second->GetNodesSuspend()) { + GELOGD("The iteration suspend of node %s has been set by pass %s", suspend_node->GetName().c_str(), + name_to_pass.first.c_str()); + g_state.nodes_suspend.insert(suspend_node); + } const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); g_state.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); } - + AddImmediateRepassNodesToQueue(node, repass_imm_nodes_to_pass_names, g_state); + AddResumeNodesToQueue(resume_nodes_to_pass_names, g_state); return SUCCESS; } } // namespace ge diff --git a/ge/graph/passes/folding_pass.cc b/ge/graph/passes/folding_pass.cc index d586a57b..819c3b40 100755 --- a/ge/graph/passes/folding_pass.cc +++ b/ge/graph/passes/folding_pass.cc @@ -363,7 +363,7 @@ Status FoldingPass::ConnectNodeToInAnchor(InDataAnchorPtr &in_anchor, NodePtr &n in_anchor->GetIdx()); return INTERNAL_ERROR; } - AddImmediateRePassNode(node); + AddRePassNodesWithInOut(node); return SUCCESS; } } // namespace ge diff --git a/ge/graph/passes/infer_base_pass.cc b/ge/graph/passes/infer_base_pass.cc index 25c45677..231fb967 100644 --- a/ge/graph/passes/infer_base_pass.cc +++ b/ge/graph/passes/infer_base_pass.cc @@ -84,8 +84,11 @@ Status InferBasePass::Run(NodePtr &node) { bool InferBasePass::NeedInfer(const NodePtr &node) const { return true; } void InferBasePass::AddChangedNodesImmediateRepass(const std::set &changed_nodes) { -// need passed_nodes set to solve the problem that multi-input operators do repass in advance. -// when there is passed_nodes set, wo should call AddImmediateRePassNode for all nodes in changed_nodes. + // need passed_nodes set to solve the problem that multi-input operators do repass in advance. + // when there is passed_nodes set, wo should call AddImmediateRePassNode for all nodes in changed_nodes. + for (const auto &node : changed_nodes) { + AddImmediateRePassNode(node); + } } graphStatus InferBasePass::InferAndUpdate(NodePtr &node, bool before_subgraph, std::set &changed_nodes) { diff --git a/ge/graph/passes/infer_base_pass.h b/ge/graph/passes/infer_base_pass.h index 911d7fc7..3900b5db 100644 --- a/ge/graph/passes/infer_base_pass.h +++ b/ge/graph/passes/infer_base_pass.h @@ -36,7 +36,7 @@ class InferBasePass : public BaseNodePass { * @param dst, output TensorDesc to be updated * @return */ - virtual graphStatus UpdateTensorDesc(GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) = 0; + virtual graphStatus UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) = 0; /** * Update the output TensorDesc for nodes which contain subgraphs. diff --git a/ge/graph/passes/infer_value_range_pass.cc b/ge/graph/passes/infer_value_range_pass.cc index 1d305fbb..03a18fdb 100644 --- a/ge/graph/passes/infer_value_range_pass.cc +++ b/ge/graph/passes/infer_value_range_pass.cc @@ -207,7 +207,7 @@ bool InferValueRangePass::InputHasUnknownValueRange(const NodePtr &node) const { return has_unknown_value_range; } -graphStatus InferValueRangePass::UpdateTensorDesc(GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) { +graphStatus InferValueRangePass::UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) { if (src == nullptr || dst == nullptr) { REPORT_CALL_ERROR("E19999", "While updating tensor desc, input desc is null."); GELOGE(GRAPH_FAILED, "[Param][check] While updating tensor desc, input desc is null."); @@ -301,12 +301,26 @@ graphStatus InferValueRangePass::ConstructData(const GeTensorDesc &tensor_desc, GeTensorPtr &output_ptr) { std::vector> value_range; (void)tensor_desc.GetValueRange(value_range); - if (static_cast(value_range.size()) != tensor_desc.GetShape().GetShapeSize()) { - GELOGW("Value range of input %s is invalid.", tensor_desc.GetName().c_str()); + size_t value_range_data_num = value_range.size(); + auto tensor_shape = tensor_desc.GetShape(); + bool value_range_and_tensor_shape_matched = true; + if (tensor_shape.IsScalar()){ + // scalar tensor has only one value_range pair + if (value_range_data_num != 1) { + value_range_and_tensor_shape_matched = false; + } + } else { + // normal tensor, value_range size is equal to tensor shape size. + if (static_cast(value_range_data_num) != tensor_shape.GetShapeSize()) { + value_range_and_tensor_shape_matched = false; + } + } + if (!value_range_and_tensor_shape_matched) { + GELOGW("Input %s value range and tensor shape do not match. Value range size is %zu, tensor shape is %s.", + tensor_desc.GetName().c_str(), value_range_data_num, formats::ShapeToString(tensor_shape).c_str()); return GRAPH_PARAM_INVALID; } - size_t value_range_data_num = value_range.size(); unique_ptr buf(new (std::nothrow) T[value_range_data_num]()); if (buf == nullptr) { REPORT_INNER_ERROR("E19999", "New buf failed"); @@ -494,10 +508,16 @@ void InferValueRangePass::ConstructValueRange(const GeTensorPtr &left_tensor, co GELOGI("Output tensor of cpu kernel does not have data, no way to set value range."); return; } - for (auto j = 0; j < left_tensor->GetTensorDesc().GetShape().GetShapeSize(); ++j) { + auto left_tensor_shape = left_tensor->GetTensorDesc().GetShape(); + for (auto j = 0; j < left_tensor_shape.GetShapeSize(); ++j) { auto left = static_cast(*(x + j)); auto right = static_cast(*(y + j)); - value_range.emplace_back(std::make_pair(left, right)); + value_range.emplace_back(left, right); + } + + if (left_tensor_shape.IsScalar()) { + GELOGD("When inferring value range, output tensors of cpu kernel are scalar tensors."); + value_range.emplace_back(static_cast(*x), static_cast(*y)); } } } // namespace ge diff --git a/ge/graph/passes/infer_value_range_pass.h b/ge/graph/passes/infer_value_range_pass.h index cc971d17..eb485c87 100644 --- a/ge/graph/passes/infer_value_range_pass.h +++ b/ge/graph/passes/infer_value_range_pass.h @@ -26,7 +26,7 @@ class InferValueRangePass : public InferBasePass { private: std::string SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const override; - graphStatus UpdateTensorDesc(GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override; + graphStatus UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override; graphStatus UpdateOutputFromSubgraphs(const std::vector &src, GeTensorDescPtr &dst) override; graphStatus UpdateOutputFromSubgraphsForMultiDims(const std::vector &src, GeTensorDescPtr &dst) override; diff --git a/ge/graph/passes/infershape_pass.cc b/ge/graph/passes/infershape_pass.cc index 58ee0348..47fd77e4 100755 --- a/ge/graph/passes/infershape_pass.cc +++ b/ge/graph/passes/infershape_pass.cc @@ -90,6 +90,9 @@ Status InferShapePass::SuspendV1LoopExitNodes(const NodePtr &node) { GELOGI("Found v1 loop when infershape, suspend Exit node %s, type %s.", anchor_2_node.second->GetName().c_str(), anchor_2_node.second->GetType().c_str()); auto iter = graphs_2_suspend_nodes_.find(GetCurrentGraphName()); + if (iter == graphs_2_suspend_nodes_.end()) { + continue; + } auto &suspend_nodes = graphs_2_suspend_nodes_[GetCurrentGraphName()]; if (suspend_nodes.nodes_set.insert(anchor_2_node.second).second) { suspend_nodes.nodes.push(anchor_2_node.second); @@ -102,7 +105,7 @@ Status InferShapePass::SuspendV1LoopExitNodes(const NodePtr &node) { Status InferShapePass::Infer(NodePtr &node) { auto ret = SuspendV1LoopExitNodes(node); if (ret != SUCCESS) { - //todo LOG + GELOGE(ret, "Failed to suspend exit node in v1 control flow loop."); return ret; } bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag(); @@ -117,7 +120,9 @@ Status InferShapePass::Infer(NodePtr &node) { if (!is_unknown_graph) { auto inference_context = ShapeRefiner::CreateInferenceContext(node); GE_CHECK_NOTNULL(inference_context); - GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size()); + vector marks; + inference_context->GetMarks(marks); + GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), marks.size()); op.SetInferenceContext(inference_context); } @@ -128,13 +133,16 @@ Status InferShapePass::Infer(NodePtr &node) { GELOGE(GRAPH_FAILED, "[Call][InferFunction] failed, node:%s.", node->GetName().c_str()); return GRAPH_FAILED; } + UpdateCurNodeOutputDesc(node); if (!is_unknown_graph) { auto ctx_after_infer = op.GetInferenceContext(); if (ctx_after_infer != nullptr) { - GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size()); - if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !ctx_after_infer->GetMarks().empty()) { + vector marks; + ctx_after_infer->GetMarks(marks); + GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), marks.size()); + if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !marks.empty()) { GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(), - ctx_after_infer->GetMarks().size()); + marks.size()); ShapeRefiner::PushToContextMap(node, ctx_after_infer); } } @@ -180,15 +188,29 @@ bool InferShapePass::SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDe return true; } -graphStatus InferShapePass::UpdateTensorDesc(GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) { - // refresh src itself - src->SetOriginShape(src->GetShape()); - src->SetOriginDataType(src->GetDataType()); - TensorUtils::SetRealDimCnt(*src, static_cast(src->GetOriginShape().GetDims().size())); - vector> src_shape_range; - src->GetShapeRange(src_shape_range); - src->SetOriginShapeRange(src_shape_range); +void InferShapePass::UpdateCurNodeOutputDesc(NodePtr &node) { + auto op_desc = node->GetOpDesc(); + for (const auto &out_anchor : node->GetAllOutDataAnchors()) { + auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx()); + GE_IF_BOOL_EXEC(output_tensor == nullptr, continue); + GE_IF_BOOL_EXEC(output_tensor->MutableShape().GetDims().empty(), + output_tensor->SetOriginShape(output_tensor->GetShape())); + + ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast(output_tensor->GetOriginShape().GetDims() + .size())); + output_tensor->SetOriginDataType(output_tensor->GetDataType()); + // set output origin shape range + std::vector> range; + (void)output_tensor->GetShapeRange(range); + output_tensor->SetOriginShapeRange(range); + GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s", + node->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(), + TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(), + TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str()); + } +} +graphStatus InferShapePass::UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) { changed = false; if (SameTensorDesc(src, dst)) { GELOGD("Peer dst tensor_desc is same as src tensor_desc. No need update."); @@ -213,7 +235,7 @@ graphStatus InferShapePass::CallInferShapeFunc(NodePtr &node, Operator &op) { auto ret = op_desc->CallInferFunc(op); if (ret == GRAPH_PARAM_INVALID) { // Op ir no infer func, try to get infer func from operator factory - auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType()); + auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType().c_str()); if (node_op.IsEmpty()) { GELOGW("get op from OperatorFactory fail. opType: %s", op_type.c_str()); return ret; @@ -318,7 +340,7 @@ graphStatus InferShapePass::UpdateOutputFromSubgraphsForMultiDims(const std::vec Status InferShapePass::OnSuspendNodesLeaked() { auto iter = graphs_2_suspend_nodes_.find(GetCurrentGraphName()); if (iter == graphs_2_suspend_nodes_.end()) { - // todo log warn + GELOGW("There is no suspend nodes on graph %s", GetCurrentGraphName().c_str()); return SUCCESS; } if (!iter->second.nodes.empty()) { diff --git a/ge/graph/passes/infershape_pass.h b/ge/graph/passes/infershape_pass.h index 2b7fdd3b..02af16ca 100644 --- a/ge/graph/passes/infershape_pass.h +++ b/ge/graph/passes/infershape_pass.h @@ -26,7 +26,7 @@ class InferShapePass : public InferBasePass { std::string SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const override; graphStatus Infer(NodePtr &node) override; - graphStatus UpdateTensorDesc(GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override; + graphStatus UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override; graphStatus UpdateOutputFromSubgraphs(const std::vector &src, GeTensorDescPtr &dst) override; graphStatus UpdateOutputFromSubgraphsForMultiDims(const std::vector &src, GeTensorDescPtr &dst) override; @@ -36,6 +36,7 @@ class InferShapePass : public InferBasePass { private: graphStatus CallInferShapeFunc(NodePtr &node, Operator &op); bool SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDescPtr &dst); + void UpdateCurNodeOutputDesc(NodePtr &node); Status SuspendV1LoopExitNodes(const NodePtr &node); struct SuspendNodes { std::stack nodes; diff --git a/ge/graph/passes/merge_pass.cc b/ge/graph/passes/merge_pass.cc index 27ea85fd..fec9c6d0 100644 --- a/ge/graph/passes/merge_pass.cc +++ b/ge/graph/passes/merge_pass.cc @@ -31,7 +31,6 @@ namespace ge { const int kValueIndexOutputIndex = 1; const size_t kCaseNoInput = 0; const size_t kCaseOneInput = 1; -const bool kWillRepassImmediately = true; Status MergePass::Run(NodePtr &node) { GELOGD("MergePass running"); @@ -83,14 +82,14 @@ Status MergePass::Run(NodePtr &node) { } auto in_node = in_data_nodes.at(0); if (IsMergeInputNeedOptimized(in_node)) { - if (IsolateAndDeleteNode(in_node, {0}, kWillRepassImmediately) != SUCCESS) { + if (IsolateAndDeleteNode(in_node, {0}) != SUCCESS) { REPORT_CALL_ERROR("E19999", "Isolate and delete node:%s(%s) failed", in_node->GetName().c_str(), in_node->GetType().c_str()); GELOGE(FAILED, "[Remove][Node] %s failed.", in_node->GetName().c_str()); return FAILED; } } - return IsolateAndDeleteNode(node, merge_io_map, kWillRepassImmediately); + return IsolateAndDeleteNode(node, merge_io_map); } default: { // Case C: input_count > 1, the merge node can not be optimized diff --git a/ge/graph/passes/switch_dead_branch_elimination.cc b/ge/graph/passes/switch_dead_branch_elimination.cc index d67a8d01..3c6c57d0 100644 --- a/ge/graph/passes/switch_dead_branch_elimination.cc +++ b/ge/graph/passes/switch_dead_branch_elimination.cc @@ -28,7 +28,6 @@ namespace { const std::vector::size_type kDataInputIndex = 0; const std::vector::size_type kPredInputIndex = 1; const int kDefaultInputIndex = -1; -const bool kWillRepassImmediately = true; bool ParsePred(const ConstGeTensorPtr &tensor) { if (tensor == nullptr) { @@ -135,7 +134,7 @@ Status SwitchDeadBranchElimination::DeleteSwitchNode(NodePtr &node, NodePtr &pre return FAILED; } switch_io_map[out_index] = kDataInputIndex; - return IsolateAndDeleteNode(node, switch_io_map, kWillRepassImmediately); + return IsolateAndDeleteNode(node, switch_io_map); } Status SwitchDeadBranchElimination::Run(NodePtr &node) { diff --git a/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc b/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc index 3cd26139..cc7f276e 100755 --- a/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc +++ b/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc @@ -568,6 +568,7 @@ Status InsertNewOpUtil::GetDataRelatedNode(NodePtr &node, std::map aipp_params(new (std::nothrow) domi::AippOpParams()); + GE_CHECK_NOTNULL(aipp_params); ge::GeAttrValue::NAMED_ATTRS aipp_attr; GE_CHK_BOOL_RET_STATUS(AttrUtils::GetNamedAttrs(data_op, ATTR_NAME_AIPP, aipp_attr), ACL_ERROR_GE_AIPP_NOT_EXIST, "[Get][Attr] %s from op:%s failed", ATTR_NAME_AIPP.c_str(), data_op->GetName().c_str()); diff --git a/ge/graph/preprocess/multi_batch_copy_graph.cc b/ge/graph/preprocess/multi_batch_copy_graph.cc index 1634c8ce..fd3a4e91 100644 --- a/ge/graph/preprocess/multi_batch_copy_graph.cc +++ b/ge/graph/preprocess/multi_batch_copy_graph.cc @@ -1206,7 +1206,7 @@ Status MultiBatchGraphCopyer::CheckCopyResult(const std::vector &start_ auto dims = NodeUtils::GetOutputDesc(*node, kDataOutIndex).GetShape().GetDims(); if (!IsAllDimsPositive(dims)) { REPORT_CALL_ERROR("E19999", "Failed to copy multi batch graph, the node %s still has unknown shape %s", - node->GetName().c_str(), formats::ShapeToString(dims).c_str()); + node->GetName().c_str(), formats::ShapeToString(dims).c_str()); GELOGE(INTERNAL_ERROR, "[Check][Param] Failed to copy multi batch graph, the node %s still has unknown shape %s", node->GetName().c_str(), formats::ShapeToString(dims).c_str()); return INTERNAL_ERROR; diff --git a/ge/hybrid/executor/hybrid_model_async_executor.cc b/ge/hybrid/executor/hybrid_model_async_executor.cc index e0dd768d..229cce84 100644 --- a/ge/hybrid/executor/hybrid_model_async_executor.cc +++ b/ge/hybrid/executor/hybrid_model_async_executor.cc @@ -295,13 +295,15 @@ Status HybridModelAsyncExecutor::PrepareInputs(const InputData ¤t_data, Hy } } tensor_desc->SetShape(shape); - args.input_desc[input_index] = tensor_desc; - GELOGD("Update shape of input[%zu] to [%s]", input_index, tensor_desc->MutableShape().ToString().c_str()); + GELOGD("Update shape[%s] of input[%zu] to [%s]", + shape.ToString().c_str(), input_index, tensor_desc->MutableShape().ToString().c_str()); GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorMemorySizeInBytes(*tensor_desc, tensor_size), "[Invoke][GetTensorMemorySizeInBytes]Failed to calc tensor size," "index = %zu, shape = [%s], model_id = %u.", input_index, tensor_desc->GetShape().ToString().c_str(), model_id_); - GELOGD("Input tensor[%zu] size = %zu", input_index, tensor_size); + GELOGD("Input tensor[%zu] size = %ld", input_index, tensor_size); + TensorUtils::SetSize(*tensor_desc, tensor_size); + args.input_desc[input_index] = tensor_desc; } GE_CHECK_GE(tensor_size, 0); diff --git a/ge/hybrid/executor/hybrid_model_executor.cc b/ge/hybrid/executor/hybrid_model_executor.cc index 58da451c..2bb683c7 100755 --- a/ge/hybrid/executor/hybrid_model_executor.cc +++ b/ge/hybrid/executor/hybrid_model_executor.cc @@ -33,9 +33,6 @@ HybridModelExecutor::HybridModelExecutor(HybridModel *model, uint32_t device_id, } HybridModelExecutor::~HybridModelExecutor() { - if (context_.rt_gen_context != nullptr) { - (void) rtCtxDestroy(context_.rt_gen_context); - } } Status HybridModelExecutor::Init() { @@ -139,7 +136,6 @@ Status HybridModelExecutor::Cleanup() { Status HybridModelExecutor::InitExecutionContext() { GE_CHK_RT_RET(rtCtxGetCurrent(&context_.rt_context)); - GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0)); GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); context_.global_step = model_->GetGlobalStep(); diff --git a/ge/hybrid/executor/hybrid_model_pipeline_executor.cc b/ge/hybrid/executor/hybrid_model_pipeline_executor.cc index 45e61138..b5e66628 100644 --- a/ge/hybrid/executor/hybrid_model_pipeline_executor.cc +++ b/ge/hybrid/executor/hybrid_model_pipeline_executor.cc @@ -191,7 +191,6 @@ HybridModelPipelineExecutor::HybridModelPipelineExecutor(HybridModel *model, uin } Status StageExecutor::InitExecutionContext() { - GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0)); GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); context_.model = model_; diff --git a/ge/hybrid/executor/worker/task_compile_engine.cc b/ge/hybrid/executor/worker/task_compile_engine.cc index f7da9acd..491e0997 100755 --- a/ge/hybrid/executor/worker/task_compile_engine.cc +++ b/ge/hybrid/executor/worker/task_compile_engine.cc @@ -21,10 +21,17 @@ namespace ge { namespace hybrid { Status TaskCompileEngine::Compile(NodeState &node_state, GraphExecutionContext *context) { - const auto &node_item = *node_state.GetNodeItem(); GE_CHECK_NOTNULL(context); + rtContext_t rt_gen_context = nullptr; + GE_CHK_RT_RET(rtCtxCreate(&rt_gen_context, RT_CTX_GEN_MODE, 0)); + std::function callback = [&]() { + (void) rtCtxDestroy(rt_gen_context); + GE_CHK_RT(rtCtxSetCurrent(context->rt_context)); + }; + GE_MAKE_GUARD(rt_gen_context, callback); + + const auto &node_item = *node_state.GetNodeItem(); RECORD_COMPILE_EVENT(context, node_item.NodeName().c_str(), "[Compile] Start"); - GE_CHK_RT_RET(rtCtxSetCurrent(context->rt_gen_context)); if (context->ge_context != nullptr) { GetThreadLocalContext() = *context->ge_context; diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index e80d9b90..554ddbbb 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -1044,6 +1044,7 @@ Status HybridModelBuilder::InitConstantOps() { } else { var_tensor.reset(new(std::nothrow)TensorValue(nullptr, 0)); } + GE_CHECK_NOTNULL(var_tensor); } else { GE_CHK_STATUS_RET_NOLOG(VarNodeToTensor(var_node, var_tensor)); GELOGD("Init const op tensor. name = %s, size = %ld", var_name.c_str(), var_tensor->GetSize()); diff --git a/ge/hybrid/model/node_item.cc b/ge/hybrid/model/node_item.cc index 8e87c6e2..77bd8efd 100644 --- a/ge/hybrid/model/node_item.cc +++ b/ge/hybrid/model/node_item.cc @@ -24,6 +24,8 @@ namespace ge { namespace hybrid { namespace { +const uint8_t kMaxTransCount = 3; +const uint32_t kTransOpIoSize = 1; const char *const kAttrNameOriginalFusionGraph = "_original_fusion_graph"; const char *const kNodeTypeRetVal = "_RetVal"; const std::set kControlOpTypes{ @@ -39,6 +41,25 @@ const std::set kMergeOpTypes{ MERGE, REFMERGE, STREAMMERGE }; +bool IsEnterFeedNode(NodePtr node) { + // For: Enter -> node + // For: Enter -> Cast -> node + // For: Enter -> TransData -> Cast -> node + for (uint8_t i = 0; i < kMaxTransCount; ++i) { + if (kEnterOpTypes.count(NodeUtils::GetNodeType(node)) > 0) { + GELOGD("Node[%u] is Enter feed node.", node->GetName().c_str()); + return true; + } + + const auto all_nodes = node->GetInDataNodes(); + if (all_nodes.size() != kTransOpIoSize || node->GetAllInDataAnchorsSize() != kTransOpIoSize) { + return false; + } + node = all_nodes.at(0); + } + return false; +} + Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) { uint32_t parent_index = 0; if (!AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { @@ -399,7 +420,7 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) { data_anchors.emplace(anchor_index); } // If Enter feed Not Merge, take as root Node. - if (IsEnterOp() && (node_item->node_type != STREAMMERGE)) { + if (IsEnterFeedNode(node) && (node_item->node_type != STREAMMERGE)) { auto &data_anchors = node_item->enter_data_[this]; data_anchors.emplace(anchor_index); } @@ -419,7 +440,7 @@ void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) { node_item->root_ctrl_.emplace(this); } // If Enter feed control signal, take as root Node. - if (IsEnterOp() && (node_item->node_type != STREAMMERGE && node_item->node_type != STREAMACTIVE)) { + if (IsEnterFeedNode(node) && (node_item->node_type != STREAMMERGE && node_item->node_type != STREAMACTIVE)) { node_item->enter_ctrl_.emplace(this); } GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); diff --git a/ge/ir_build/option_utils.cc b/ge/ir_build/option_utils.cc index 16586c4e..7287fe91 100755 --- a/ge/ir_build/option_utils.cc +++ b/ge/ir_build/option_utils.cc @@ -50,6 +50,8 @@ const std::set kBufferOptimizeSupportOption = {"l1_optimize", "l2_o const char *const kBufferOptimizeSupport = "only support l2_optimize, off_optimize"; const char *const IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT = "high_performance"; const char *const IR_OPTION_OP_SELECT_IMPLMODE_PRECISON = "high_precision"; +const char *const IR_OPTION_OP_SELECT_IMPLMODE_HIGH_PRECISION_FOR_ALL = "high_precision_for_all"; +const char *const IR_OPTION_OP_SELECT_IMPLMODE_HIGH_PERFORMANCE_FOR_ALL = "high_performance_for_all"; const char *const kInputShapeSample1 = "\"input_name1:n1,c1,h1,w1\""; const char *const kInputShapeSample2 = "\"input_name1:1,3,224,224\""; const char *const kSplitError1 = "size not equal to 2 split by \":\""; @@ -57,7 +59,8 @@ const char *const kEmptyError = "can not be empty"; const char *const kFloatNumError = "exist float number"; const char *const kDigitError = "is not digit"; const char *const kCompressWeightError = "it must be appointed when appoint parameter[--optypelist_for_implmode]"; -const char *const kSelectImplmodeError = "only support high_performance, high_precision"; +const char *const kSelectImplmodeError = "only support high_performance, high_precision, " + "high_precision_for_all, high_performance_for_all"; const char *const kDynamicBatchSizeError = "It can only contains digit, \",\", \" \""; const char *const kDynamicImageSizeError = "It can only contains digit, \",\", \" \" and \";\""; const char *const kKeepDtypeError = "file not found"; @@ -782,7 +785,9 @@ Status CheckImplmodeParamValid(const std::string &optypelist_for_implmode, std:: op_select_implmode = IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT; } else { if (op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT && - op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_PRECISON) { + op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_PRECISON && + op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_HIGH_PRECISION_FOR_ALL && + op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_HIGH_PERFORMANCE_FOR_ALL) { ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, {"--op_select_implmode", op_select_implmode.c_str(), kSelectImplmodeError}); diff --git a/ge/offline/main.cc b/ge/offline/main.cc index 4837653f..bc3b823d 100755 --- a/ge/offline/main.cc +++ b/ge/offline/main.cc @@ -143,7 +143,8 @@ DEFINE_string(output_type, "", DEFINE_string(op_select_implmode, "", "Optional; op select implmode! " - "Support high_precision, high_performance."); + "Support high_precision, high_performance, " + "high_precision_for_all, high_performance_for_all."); DEFINE_string(optypelist_for_implmode, "", "Optional; Nodes need use implmode selected in op_select_implmode " @@ -311,8 +312,8 @@ class GFlagUtils { "scenarios by using a configuration file.\n" " --auto_tune_mode Set tune mode. E.g.: \"GA,RL\", support configure multiple, spit by ,\n" " --op_bank_path Set the path of the custom repository generated after operator tuning with Auto Tune.\n" - " --op_select_implmode Set op select implmode. Support high_precision, high_performance. " - "default: high_performance\n" + " --op_select_implmode Set op select implmode. Support high_precision, high_performance, " + "high_precision_for_all, high_performance_for_all. default: high_performance\n" " --optypelist_for_implmode Appoint which op to select implmode, cooperated with op_select_implmode.\n" " Separate multiple nodes with commas (,). Use double quotation marks (\") " "to enclose each argument. E.g.: \"node_name1,node_name2\"\n" diff --git a/ge/single_op/single_op_model.cc b/ge/single_op/single_op_model.cc index 08a0fcbc..9a52a83d 100755 --- a/ge/single_op/single_op_model.cc +++ b/ge/single_op/single_op_model.cc @@ -95,35 +95,6 @@ Status CheckInferDepend(GeModelPtr &ge_model, bool &is_infer_depend, bool &is_ho } return SUCCESS; } - -Status NeedHybridModel(GeModelPtr &ge_model, bool &flag) { - bool is_infer_depend = false; - bool is_host_mem = false; - GE_CHK_STATUS_RET(CheckInferDepend(ge_model, is_infer_depend, is_host_mem), "[Check][InferDepend] failed."); - bool need_d2h_cpy = is_infer_depend && !is_host_mem; - auto tasks = ge_model->GetModelTaskDefPtr()->task(); - int32_t kernel_task_num = 0; - for (int i = 0; i < tasks.size(); ++i) { - auto task_type = static_cast(tasks[i].type()); - if (task_type == RT_MODEL_TASK_KERNEL || task_type == RT_MODEL_TASK_ALL_KERNEL) { - const auto &context = task_type == RT_MODEL_TASK_KERNEL ? tasks[i].kernel().context() : - tasks[i].kernel_with_handle().context(); - auto kernel_type = static_cast(context.kernel_type()); - if (kernel_type == ccKernelType::TE) { - if (need_d2h_cpy) { - flag = true; - return SUCCESS; - } - kernel_task_num++; - if (kernel_task_num > 1) { - flag = true; - return SUCCESS; - } - } - } - } - return SUCCESS; -} } // namespace SingleOpModel::SingleOpModel(const std::string &model_name, const void *model_data, uint32_t model_size) @@ -558,14 +529,15 @@ Status SingleOpModel::BuildOp(StreamResource &resource, SingleOp &single_op) { return BuildTaskList(&resource, single_op); } -Status SingleOpModel::BuildModelTaskKernel(StreamResource *stream_resource, const TaskDef &task_def, - DynamicSingleOp &single_op) { - auto task_type = static_cast(task_def.type()); - const auto &context = task_type == RT_MODEL_TASK_KERNEL ? task_def.kernel().context() : - task_def.kernel_with_handle().context(); +Status SingleOpModel::BuildTaskListForDynamicOp(StreamResource *stream_resource, DynamicSingleOp &single_op) { + auto ge_model = model_helper_.GetGeModel(); + GE_CHECK_NOTNULL(ge_model); - auto kernel_type = static_cast(context.kernel_type()); - if (kernel_type == ccKernelType::TE) { + auto compute_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph()); + GE_CHECK_NOTNULL(compute_graph); + single_op.compute_graph_ = compute_graph; + if (tbe_tasks_.size() > 0) { + const auto &task_def = tbe_tasks_[0]; GELOGD("Building TBE task."); TbeOpTask *tbe_task = nullptr; GE_CHK_STATUS_RET_NOLOG(BuildKernelTask(task_def, &tbe_task)); @@ -575,71 +547,81 @@ Status SingleOpModel::BuildModelTaskKernel(StreamResource *stream_resource, cons tbe_task->stream_resource_ = stream_resource; } single_op.op_task_.reset(tbe_task); - } else if (kernel_type == ccKernelType::AI_CPU || kernel_type == ccKernelType::CUST_AI_CPU) { - GELOGD("Building AICPU_CC task"); - OpTask *task = nullptr; - uint64_t dynamic_singleop_kernel_id = aicpu_kernel_id++; - GELOGI("Build dynamic singleOp CCTask, kernel_id = %lu", dynamic_singleop_kernel_id); - GE_CHK_STATUS_RET_NOLOG(BuildCpuKernelTask(task_def.kernel(), &task, dynamic_singleop_kernel_id)); - task->SetModelArgs(model_name_, model_id_); - single_op.op_task_.reset(task); - } else { - GELOGE(ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID, - "[Check][Param:TaskDef]Only TBE, AI_CPU, CUST_AI_CPU kernel are supported, but got %u", - context.kernel_type()); - REPORT_INNER_ERROR("E19999", - "BuildModelTaskKernel fail for got:%u not supported, Only TBE, AI_CPU, CUST_AI_CPU kernel are supported.", - context.kernel_type()); - return ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID; - } - return SUCCESS; -} - -Status SingleOpModel::BuildTaskListForDynamicOp(StreamResource *stream_resource, DynamicSingleOp &single_op) { - auto ge_model = model_helper_.GetGeModel(); - GE_CHECK_NOTNULL(ge_model); - - auto compute_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph()); - GE_CHECK_NOTNULL(compute_graph); - single_op.compute_graph_ = compute_graph; - auto tasks = ge_model->GetModelTaskDefPtr()->task(); - for (int i = 0; i < tasks.size(); ++i) { - const TaskDef &task_def = tasks[i]; - GELOGI("[%s] Task[%d], type = [%u], DebugString = [%s]", model_name_.c_str(), i, task_def.type(), - task_def.DebugString().c_str()); + } else if (aicpu_tasks_.size() > 0) { + const auto &task_def = aicpu_tasks_[0]; auto task_type = static_cast(task_def.type()); - if (task_type == RT_MODEL_TASK_KERNEL || task_type == RT_MODEL_TASK_ALL_KERNEL) { - if (single_op.op_task_ != nullptr) { - GELOGE(ACL_ERROR_GE_OP_TASK_TYPE_INVALID, "[Check][TaskType]Do not support dynamic op with multiple tasks."); - REPORT_INNER_ERROR("E19999", - "BuildTaskListForDynamicOp fail for Do not support dynamic op with multiple tasks."); - return ACL_ERROR_GE_OP_TASK_TYPE_INVALID; - } - GE_CHK_STATUS_RET_NOLOG(BuildModelTaskKernel(stream_resource, task_def, single_op)); + if (task_type == RT_MODEL_TASK_KERNEL) { + GELOGD("Building AICPU_CC task"); + OpTask *task = nullptr; + uint64_t dynamic_singleop_kernel_id = aicpu_kernel_id++; + GELOGI("Build dynamic singleOp CCTask, kernel_id = %lu", dynamic_singleop_kernel_id); + GE_CHK_STATUS_RET_NOLOG(BuildCpuKernelTask(task_def.kernel(), &task, dynamic_singleop_kernel_id)); + task->SetModelArgs(model_name_, model_id_); + single_op.op_task_.reset(task); } else if (task_type == RT_MODEL_TASK_KERNEL_EX) { - if (single_op.op_task_ != nullptr) { - GELOGE(ACL_ERROR_GE_OP_TASK_TYPE_INVALID, "[Check][TaskType]Do not support dynamic op with multiple tasks."); - REPORT_INNER_ERROR("E19999", - "BuildTaskListForDynamicOp fail for Do not support dynamic op with multiple tasks."); - return ACL_ERROR_GE_OP_TASK_TYPE_INVALID; - } GELOGD("Building AICPU_TF task"); AiCpuTask *aicpu_task = nullptr; uint64_t dynamic_singleop_kernel_id = aicpu_kernel_id++; GELOGI("Build dynamic singleOp TfTask, kernel_id = %lu", dynamic_singleop_kernel_id); GE_CHK_STATUS_RET_NOLOG(BuildKernelExTask(task_def.kernel_ex(), &aicpu_task, dynamic_singleop_kernel_id)); if (aicpu_task->GetUnknownType() == DEPEND_COMPUTE) { - if (i >= tasks.size() - 1) { + if (aicpu_tasks_.size() < 2) { GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[Check][Task]The copy task of the fourth operator was not found."); REPORT_INNER_ERROR("E19999", "The copy task of the fourth operator was not found."); return ACL_ERROR_GE_PARAM_INVALID; } - ++i; - const TaskDef ©_task_def = tasks[i]; + const TaskDef ©_task_def = aicpu_tasks_[1]; GE_CHK_STATUS_RET_NOLOG(aicpu_task->SetMemCopyTask(copy_task_def.kernel_ex())); } aicpu_task->SetModelArgs(model_name_, model_id_); single_op.op_task_.reset(aicpu_task); + } + } + return SUCCESS; +} + +Status SingleOpModel::NeedHybridModel(GeModelPtr &ge_model, bool &need_hybrid_model) { + bool is_infer_depend = false; + bool is_host_mem = false; + GE_CHK_STATUS_RET(CheckInferDepend(ge_model, is_infer_depend, is_host_mem), "[Check][InferDepend] failed."); + bool need_d2h_cpy = is_infer_depend && !is_host_mem; + bool aicpu_multi_task = tbe_tasks_.size() >= 1 && aicpu_tasks_.size() >= 1; + bool aicore_multi_task = tbe_tasks_.size() > 1; + need_hybrid_model = need_d2h_cpy || aicore_multi_task || aicpu_multi_task; + return SUCCESS; +} + +Status SingleOpModel::ParseTasks() { + auto ge_model = model_helper_.GetGeModel(); + GE_CHECK_NOTNULL(ge_model); + + auto tasks = ge_model->GetModelTaskDefPtr()->task(); + for (int i = 0; i < tasks.size(); ++i) { + TaskDef &task_def = tasks[i]; + GELOGI("[%s] Task[%d], type = [%u], DebugString = [%s]", model_name_.c_str(), i, task_def.type(), + task_def.DebugString().c_str()); + auto task_type = static_cast(task_def.type()); + if (task_type == RT_MODEL_TASK_KERNEL) { + const auto &kernel_def = task_def.kernel(); + const auto &context = kernel_def.context(); + auto kernel_type = static_cast(context.kernel_type()); + if (kernel_type == ccKernelType::TE) { + tbe_tasks_.emplace_back(task_def); + } else if (kernel_type == ccKernelType::AI_CPU || kernel_type == ccKernelType::CUST_AI_CPU) { + aicpu_tasks_.emplace_back(task_def); + } else { + GELOGE(ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID, + "[Check][Param:TaskDef]Only TBE, AI_CPU, CUST_AI_CPU kernel are supported, but got %u", + context.kernel_type()); + REPORT_INNER_ERROR("E19999", + "BuildModelTaskKernel fail for got:%u not supported, Only TBE, AI_CPU, CUST_AI_CPU kernel are supported.", + context.kernel_type()); + return ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID; + } + } else if (task_type == RT_MODEL_TASK_ALL_KERNEL) { + tbe_tasks_.emplace_back(task_def); + } else if (task_type == RT_MODEL_TASK_KERNEL_EX) { + aicpu_tasks_.emplace_back(task_def); } else { // skip GELOGD("Skip task type: %d", static_cast(task_type)); @@ -654,6 +636,7 @@ Status SingleOpModel::BuildDynamicOp(StreamResource &resource, DynamicSingleOp & GE_CHK_STATUS_RET_NOLOG(InitModelMem(resource)); model_params_.memory_size = UINT64_MAX; model_params_.graph_is_dynamic = true; + GE_CHK_STATUS_RET(ParseTasks(), "[Parse][Tasks] failed."); auto ge_model = model_helper_.GetGeModel(); GE_CHECK_NOTNULL(ge_model); diff --git a/ge/single_op/single_op_model.h b/ge/single_op/single_op_model.h index b7f6b42a..45616d9a 100755 --- a/ge/single_op/single_op_model.h +++ b/ge/single_op/single_op_model.h @@ -71,13 +71,16 @@ class SingleOpModel { Status BuildKernelTask(const domi::TaskDef &task_def, TbeOpTask **task); Status BuildKernelExTask(const domi::KernelExDef &kernel_def, AiCpuTask **task, uint64_t kernel_id); Status BuildCpuKernelTask(const domi::KernelDef &kernel_def, OpTask **task, uint64_t kernel_id); - Status BuildModelTaskKernel(StreamResource *stream_resource, const domi::TaskDef &task_def, - DynamicSingleOp &single_op); static void ParseOpModelParams(ModelHelper &model_helper, SingleOpModelParam ¶m); void ParseArgTable(OpTask *task, SingleOp &op); Status InitHybridModelExecutor(const StreamResource &resource, const GeModelPtr &ge_model, SingleOp &single_op); Status SetHostMemTensor(DynamicSingleOp &single_op); + Status NeedHybridModel(GeModelPtr &ge_model, bool &flag); + Status ParseTasks(); + + std::vector tbe_tasks_; + std::vector aicpu_tasks_; std::string model_name_; uint32_t model_id_ = 0; diff --git a/ge/single_op/task/op_task.h b/ge/single_op/task/op_task.h index 0aa2b617..085bb5ff 100644 --- a/ge/single_op/task/op_task.h +++ b/ge/single_op/task/op_task.h @@ -33,6 +33,10 @@ #include "register/op_tiling.h" namespace ge { +namespace { +const int kAddressNum = 2; +} // namespace + class StreamResource; struct SingleOpModelParam; class OpTask { @@ -264,7 +268,7 @@ class MemcpyAsyncTask : public OpTask { friend class SingleOpModel; friend class RtsKernelTaskBuilder; - uintptr_t addresses_[2]; + uintptr_t addresses_[kAddressNum]; size_t dst_max_; size_t count_; rtMemcpyKind_t kind_; diff --git a/ge/single_op/task/tbe_task_builder.cc b/ge/single_op/task/tbe_task_builder.cc index db8ecfe2..c1bafed8 100644 --- a/ge/single_op/task/tbe_task_builder.cc +++ b/ge/single_op/task/tbe_task_builder.cc @@ -104,7 +104,7 @@ Status TbeTaskBuilder::DoRegisterBinary(const OpKernelBin &kernel_bin, void **bi binary.version = 0; binary.data = kernel_bin.GetBinData(); binary.length = kernel_bin.GetBinDataSize(); - binary.magic = param.core_type == 0 ? RT_DEV_BINARY_MAGIC_ELF : RT_DEV_BINARY_MAGIC_ELF_AIVEC; + GE_CHK_STATUS_RET_NOLOG(GetMagic(binary.magic)); Status ret = 0; if (task_def_.type() == RT_MODEL_TASK_ALL_KERNEL) { ret = rtRegisterAllKernel(&binary, bin_handle); @@ -416,4 +416,27 @@ Status TbeTaskBuilder::InitTilingInfo(TbeOpTask &task) { task.EnableDynamicSupport(node_, tiling_buffer, static_cast(max_size)); return SUCCESS; } + +Status TbeTaskBuilder::GetMagic(uint32_t &magic) const { + std::string json_string; + GE_IF_BOOL_EXEC(AttrUtils::GetStr(op_desc_, TVM_ATTR_NAME_MAGIC, json_string), + GELOGD("Get original type of session_graph_id.")); + if (json_string == "RT_DEV_BINARY_MAGIC_ELF") { + magic = RT_DEV_BINARY_MAGIC_ELF; + } else if (json_string == "RT_DEV_BINARY_MAGIC_ELF_AIVEC") { + magic = RT_DEV_BINARY_MAGIC_ELF_AIVEC; + } else if (json_string == "RT_DEV_BINARY_MAGIC_ELF_AICUBE") { + magic = RT_DEV_BINARY_MAGIC_ELF_AICUBE; + } else { + REPORT_INNER_ERROR("E19999", "Attr:%s in op:%s(%s), value:%s check invalid", + TVM_ATTR_NAME_MAGIC.c_str(), op_desc_->GetName().c_str(), + op_desc_->GetType().c_str(), json_string.c_str()); + GELOGE(PARAM_INVALID, "[Check][Param] Attr:%s in op:%s(%s), value:%s check invalid", + TVM_ATTR_NAME_MAGIC.c_str(), op_desc_->GetName().c_str(), + op_desc_->GetType().c_str(), json_string.c_str()); + return PARAM_INVALID; + } + return SUCCESS; +} + } // namespace ge diff --git a/ge/single_op/task/tbe_task_builder.h b/ge/single_op/task/tbe_task_builder.h index a202cbf1..6252feea 100755 --- a/ge/single_op/task/tbe_task_builder.h +++ b/ge/single_op/task/tbe_task_builder.h @@ -105,6 +105,7 @@ class TbeTaskBuilder { const SingleOpModelParam ¶m); Status DoRegisterBinary(const OpKernelBin &kernel_bin, void **bin_handle, const SingleOpModelParam ¶m) const; Status DoRegisterMeta(void *bin_handle); + Status GetMagic(uint32_t &magic) const; static Status DoRegisterFunction(void *bin_handle, const char *stub_name, const char *kernel_name); diff --git a/metadef b/metadef index 9e4a51a9..9c9907b7 160000 --- a/metadef +++ b/metadef @@ -1 +1 @@ -Subproject commit 9e4a51a9602195b82e326b853f5adbfefc3972b6 +Subproject commit 9c9907b76a457f456072af96b8cbcfb7943beccc diff --git a/parser b/parser index 79536a19..15a27afe 160000 --- a/parser +++ b/parser @@ -1 +1 @@ -Subproject commit 79536a196f89cf7a1f5852ff7304b9a7d7b12eff +Subproject commit 15a27afefe45f2abdb78787d629163aab9437599 diff --git a/scripts/env/Dockerfile b/scripts/env/Dockerfile index af02f7bb..923a1453 100755 --- a/scripts/env/Dockerfile +++ b/scripts/env/Dockerfile @@ -38,5 +38,20 @@ RUN wget https://github.com/ccup/lcov/archive/refs/tags/add_lcov.tar.gz -O add_l ENV PROJECT_HOME=/code/Turing/graphEngine +RUN mkdir /var/run/sshd +RUN echo "root:root" | chpasswd +RUN sed -i 's/\#PermitRootLogin prohibit-password/PermitRootLogin yes/' /etc/ssh/sshd_config +RUN sed 's@session\s*required\s*pam_loginuid.so@session optional pam_loginuid.so@g' -i /etc/pam.d/sshd + +ENV NOTVISIBLE "in users profile" +RUN echo "export VISIBLE=now" >> /etc/profile + +EXPOSE 22 7777 + +RUN useradd -ms /bin/bash debugger +RUN echo "debugger:ge123" | chpasswd + +CMD ["/usr/sbin/sshd" "-D" "&"] + RUN echo "alias ge=/code/Turing/graphEngine/scripts/ge.sh">>~/.bashrc diff --git a/scripts/env/ge_env.sh b/scripts/env/ge_env.sh index 18c6aa5d..10ca810f 100755 --- a/scripts/env/ge_env.sh +++ b/scripts/env/ge_env.sh @@ -21,7 +21,7 @@ MOUNT_PROJECT_HOME=$(cd $PROJECT_HOME || return; pwd) DOCKER_BUILD_ENV_NAME=${MOUNT_PROJECT_HOME#*/} DOCKER_BUILD_ENV_NAME=${DOCKER_BUILD_ENV_NAME//\//\_} -DOCKER_IMAGE_TAG=ge_build_env.1.0.6 +DOCKER_IMAGE_TAG=ge_build_env.1.0.9 DOCKER_IAMGE_NAME=joycode2art/turing DOCKER_FULL_IMAGE_NAME=${DOCKER_IAMGE_NAME}:${DOCKER_IMAGE_TAG} @@ -61,7 +61,7 @@ function enter_docker_env(){ if test -z "$(docker images |grep ${DOCKER_IAMGE_NAME} | grep ${DOCKER_IMAGE_TAG})"; then echo "please run 'ge env --pull' to download images first!" elif test -z "$(docker ps -a |grep ${DOCKER_BUILD_ENV_NAME})"; then - $docker_cmd run -it -v ${MOUNT_PROJECT_HOME}:/code/Turing/graphEngine --workdir ${docker_work_dir} --name ${DOCKER_BUILD_ENV_NAME} ${DOCKER_FULL_IMAGE_NAME} ${docker_bash_dir} + $docker_cmd run -p 7002:22 -p 7003:7777 --privileged=true -it -v ${MOUNT_PROJECT_HOME}:/code/Turing/graphEngine --workdir ${docker_work_dir} --name ${DOCKER_BUILD_ENV_NAME} ${DOCKER_FULL_IMAGE_NAME} ${docker_bash_dir} elif test -z "$(docker ps |grep ${DOCKER_BUILD_ENV_NAME})"; then $docker_cmd start ${DOCKER_BUILD_ENV_NAME} $docker_cmd exec -w ${docker_work_dir} -it ${DOCKER_BUILD_ENV_NAME} ${docker_bash_dir} diff --git a/tests/depends/cce/CMakeLists.txt b/tests/depends/cce/CMakeLists.txt index 7550c63f..05fa8133 100644 --- a/tests/depends/cce/CMakeLists.txt +++ b/tests/depends/cce/CMakeLists.txt @@ -60,6 +60,7 @@ set(SRCS "${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" "${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" "${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" + "${GE_CODE_DIR}/metadef/graph/utils/dumper/ge_graph_dumper.cc" "${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" "${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" "${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc" diff --git a/tests/framework/CMakeLists.txt b/tests/framework/CMakeLists.txt index 8a2218b4..bbab454b 100644 --- a/tests/framework/CMakeLists.txt +++ b/tests/framework/CMakeLists.txt @@ -17,16 +17,3 @@ include(cmake/graphengine.cmake) add_subdirectory(easy_graph) add_subdirectory(ge_graph_dsl) add_subdirectory(ge_running_env) - -file(GLOB_RECURSE UTILS_SRC CONFIGURE_DEPENDS - "utils/*.cc" - ) - -add_library(framework STATIC ${UTILS_SRC}) - -target_include_directories(framework - PUBLIC utils/ -) - -set_target_properties(framework PROPERTIES CXX_STANDARD 11) -target_link_libraries(framework PUBLIC ge_graph_dsl ge_with_env) diff --git a/tests/framework/easy_graph/include/easy_graph/builder/graph_dsl.h b/tests/framework/easy_graph/include/easy_graph/builder/graph_dsl.h index 4d430983..46bfe324 100644 --- a/tests/framework/easy_graph/include/easy_graph/builder/graph_dsl.h +++ b/tests/framework/easy_graph/include/easy_graph/builder/graph_dsl.h @@ -26,16 +26,32 @@ EG_NS_BEGIN //////////////////////////////////////////////////////////////// namespace detail { -template +template Graph BuildGraph(const char *name, GRAPH_BUILDER builderInDSL) { GraphBuilder builder(name); builderInDSL(builder); return std::move(*builder); } + +struct GraphDefiner { + GraphDefiner(const char *defaultName, const char *specifiedName = nullptr) { + name = specifiedName ? specifiedName : defaultName; + } + + template + auto operator|(USER_BUILDER &&userBuilder) { + GraphBuilder graphBuilder{name}; + std::forward(userBuilder)(graphBuilder); + return *graphBuilder; + } + + private: + const char *name; +}; + } // namespace detail -#define HAS_NAME(...) NOT_EMPTY_SELECT(__VA_ARGS__) -#define DEF_GRAPH(G, ...) ::EG_NS::Graph G = ::EG_NS::detail::BuildGraph(HAS_NAME(__VA_ARGS__)(__VA_ARGS__, #G), [&](::EG_NS::GraphBuilder& BUILDER) +#define DEF_GRAPH(G, ...) ::EG_NS::Graph G = ::EG_NS::detail::GraphDefiner(#G, ##__VA_ARGS__) | [&](auto &&BUILDER) #define DATA_CHAIN(...) ::EG_NS::ChainBuilder(BUILDER, ::EG_NS::EdgeType::DATA)->__VA_ARGS__ #define CTRL_CHAIN(...) ::EG_NS::ChainBuilder(BUILDER, ::EG_NS::EdgeType::CTRL)->__VA_ARGS__ #define CHAIN(...) DATA_CHAIN(__VA_ARGS__) diff --git a/tests/framework/easy_graph/src/layout/graph_layout.cc b/tests/framework/easy_graph/src/layout/graph_layout.cc index 340acf67..716bed8a 100644 --- a/tests/framework/easy_graph/src/layout/graph_layout.cc +++ b/tests/framework/easy_graph/src/layout/graph_layout.cc @@ -16,10 +16,15 @@ #include "easy_graph/layout/graph_layout.h" #include "easy_graph/layout/layout_executor.h" +#include "easy_graph/layout/engines/graph_easy/graph_easy_executor.h" #include "easy_graph/graph/graph.h" EG_NS_BEGIN +namespace { +GraphEasyExecutor default_executor; +} + void GraphLayout::Config(LayoutExecutor &executor, const LayoutOption *opts) { this->executor_ = &executor; options_ = opts; @@ -27,8 +32,7 @@ void GraphLayout::Config(LayoutExecutor &executor, const LayoutOption *opts) { Status GraphLayout::Layout(const Graph &graph, const LayoutOption *opts) { const LayoutOption *options = opts ? opts : this->options_; - if (!executor_) - return EG_UNIMPLEMENTED; + if (!executor_) return static_cast(default_executor).Layout(graph, options); return executor_->Layout(graph, options); } diff --git a/tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/assert_error.h b/tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/assert_error.h new file mode 100644 index 00000000..7f5d5086 --- /dev/null +++ b/tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/assert_error.h @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef D52AA06185E34BBFB714FFBCDAB0D53A +#define D52AA06185E34BBFB714FFBCDAB0D53A + +#include "ge_graph_dsl/ge.h" +#include +#include + +GE_NS_BEGIN + +struct AssertError : std::exception { + AssertError(const char *file, int line, const std::string &info); + + private: + const char *what() const noexcept override; + + private: + std::string info; +}; + +GE_NS_END + +#endif \ No newline at end of file diff --git a/tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/check_utils.h b/tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/check_utils.h new file mode 100644 index 00000000..fa0ae783 --- /dev/null +++ b/tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/check_utils.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef INC_31309AA0A4E44C009C22AD9351BF3410 +#define INC_31309AA0A4E44C009C22AD9351BF3410 + +#include "ge_graph_dsl/ge.h" +#include "graph/compute_graph.h" + +GE_NS_BEGIN + +using GraphCheckFun = std::function; +struct CheckUtils { + static bool CheckGraph(const std::string &phase_id, const GraphCheckFun &fun); + static void init(); +}; + +GE_NS_END + +#endif \ No newline at end of file diff --git a/tests/framework/utils/builder/tensor_builder_utils.cc b/tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/filter_scope_guard.h similarity index 68% rename from tests/framework/utils/builder/tensor_builder_utils.cc rename to tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/filter_scope_guard.h index f99b9107..a208c02e 100644 --- a/tests/framework/utils/builder/tensor_builder_utils.cc +++ b/tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/filter_scope_guard.h @@ -1,17 +1,32 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensor_builder_utils.h" +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef C8B32320BD4943D588594B82FFBF2685 +#define C8B32320BD4943D588594B82FFBF2685 + +#include +#include +#include "ge_graph_dsl/ge.h" + +GE_NS_BEGIN + +struct FilterScopeGuard { + FilterScopeGuard(const std::vector &); + ~FilterScopeGuard(); +}; + +GE_NS_END + +#endif diff --git a/tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/graph_assert.h b/tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/graph_assert.h new file mode 100644 index 00000000..663907a0 --- /dev/null +++ b/tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/graph_assert.h @@ -0,0 +1,59 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef AD954C4ADF5B44F5B1CC8BCD72EE9ED6 +#define AD954C4ADF5B44F5B1CC8BCD72EE9ED6 + +#include "ge_graph_dsl/ge.h" +#include "ge_graph_dsl/assert/check_utils.h" +#include "ge_graph_dsl/assert/assert_error.h" +#include "ge_graph_dsl/assert/filter_scope_guard.h" + +GE_NS_BEGIN + +#ifdef GTEST_MESSAGE_AT_ +#define GRAPH_CHECK_MESSAGE(file, line, message) \ + GTEST_MESSAGE_AT_(file, line, message, ::testing::TestPartResult::kFatalFailure) +#elif +#define GRAPH_CHECK_MESSAGE(file, line, message) throw AssertError(file, line, message) +#endif + +namespace detail { +struct GraphAssert { + GraphAssert(const char *file, unsigned int line, const std::string &phase_id) + : file_(file), line_(line), phase_id_(phase_id) {} + + void operator|(const ::GE_NS::GraphCheckFun &check_fun) { + bool ret = ::GE_NS::CheckUtils::CheckGraph(phase_id_, check_fun); + if (!ret) { + auto message = "expect dump graph in phase: [" + phase_id_ + "], while not find the dump graph! "; + GRAPH_CHECK_MESSAGE(file_, line_, message.c_str()); + } + } + + private: + const char *file_; + unsigned int line_; + const std::string phase_id_; +}; +} // namespace detail + +#define DUMP_GRAPH_WHEN(...) ::GE_NS::FilterScopeGuard guard__COUNTER__({__VA_ARGS__}); +#define CHECK_GRAPH(phase_id) \ + ::GE_NS::detail::GraphAssert(__FILE__, __LINE__, #phase_id) | [&](const ::GE_NS::ComputeGraphPtr &graph) + +GE_NS_END + +#endif \ No newline at end of file diff --git a/tests/framework/ge_graph_dsl/include/ge_graph_dsl/op_desc/op_desc_cfg.h b/tests/framework/ge_graph_dsl/include/ge_graph_dsl/op_desc/op_desc_cfg.h index bb2326ec..99eafa7f 100644 --- a/tests/framework/ge_graph_dsl/include/ge_graph_dsl/op_desc/op_desc_cfg.h +++ b/tests/framework/ge_graph_dsl/include/ge_graph_dsl/op_desc/op_desc_cfg.h @@ -33,14 +33,12 @@ struct OpDescCfg { std::vector shape_; }; - OpDescCfg(const OpType &type, int in_cnt = 0, int out_cnt = 0, Format format = FORMAT_NCHW, + OpDescCfg(const OpType &type, int in_cnt = 1, int out_cnt = 1, Format format = FORMAT_NCHW, DataType data_type = DT_FLOAT, std::vector shape = {1, 1, 224, 224}) : type_(type), in_cnt_(in_cnt), out_cnt_(out_cnt), default_tensor_(format, data_type, shape) {} protected: - OpType GetType() const { - return type_; - } + OpType GetType() const { return type_; } OpType type_; int in_cnt_; int out_cnt_; diff --git a/tests/framework/ge_graph_dsl/src/assert/assert_error.cc b/tests/framework/ge_graph_dsl/src/assert/assert_error.cc new file mode 100644 index 00000000..5b74d852 --- /dev/null +++ b/tests/framework/ge_graph_dsl/src/assert/assert_error.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ge_graph_dsl/assert/assert_error.h" + +GE_NS_BEGIN + +AssertError::AssertError(const char *file, int line, const std::string &info) { + this->info = std::string(file) + ":" + std::to_string(line) + "\n" + info; +} + +const char *AssertError::what() const noexcept { return info.c_str(); } + +GE_NS_END diff --git a/tests/framework/ge_graph_dsl/src/assert/check_utils.cc b/tests/framework/ge_graph_dsl/src/assert/check_utils.cc new file mode 100644 index 00000000..56bc6e81 --- /dev/null +++ b/tests/framework/ge_graph_dsl/src/assert/check_utils.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ge_graph_dsl/assert/check_utils.h" +#include "graph/utils/dumper/ge_graph_dumper.h" +#include "ge_graph_default_checker.h" +#include "ge_graph_check_dumper.h" + +GE_NS_BEGIN + +bool CheckUtils::CheckGraph(const std::string &phase_id, const GraphCheckFun &fun) { + auto &dumper = dynamic_cast(GraphDumperRegistry::GetDumper()); + return dumper.CheckFor(GeGraphDefaultChecker(phase_id, fun)); +} + +void CheckUtils::init() { + static GeGraphCheckDumper checkDumper; + GraphDumperRegistry::Register(checkDumper); +} + +GE_NS_END diff --git a/tests/framework/ge_graph_dsl/src/assert/filter_scope_guard.cc b/tests/framework/ge_graph_dsl/src/assert/filter_scope_guard.cc new file mode 100644 index 00000000..4aa4795d --- /dev/null +++ b/tests/framework/ge_graph_dsl/src/assert/filter_scope_guard.cc @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ge_graph_dsl/assert/filter_scope_guard.h" +#include "graph/utils/dumper/ge_graph_dumper.h" +#include "ge_dump_filter.h" + +GE_NS_BEGIN + +namespace { +GeDumpFilter &GetDumpFilter() { return dynamic_cast(GraphDumperRegistry::GetDumper()); } +} // namespace + +FilterScopeGuard::FilterScopeGuard(const std::vector &filter) { GetDumpFilter().Update(filter); } + +FilterScopeGuard::~FilterScopeGuard() { GetDumpFilter().Reset(); } + +GE_NS_END \ No newline at end of file diff --git a/tests/framework/ge_graph_dsl/src/assert/ge_dump_filter.h b/tests/framework/ge_graph_dsl/src/assert/ge_dump_filter.h new file mode 100644 index 00000000..47967c91 --- /dev/null +++ b/tests/framework/ge_graph_dsl/src/assert/ge_dump_filter.h @@ -0,0 +1,33 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef INC_4C6224E8F7474EF89B18CCB0E4B19FD6 +#define INC_4C6224E8F7474EF89B18CCB0E4B19FD6 + +#include +#include +#include "ge_graph_dsl/ge.h" +#include "easy_graph/infra/keywords.h" + +GE_NS_BEGIN + +INTERFACE(GeDumpFilter) { + ABSTRACT(void Update(const std::vector &)); + ABSTRACT(void Reset()); +}; + +GE_NS_END + +#endif diff --git a/tests/framework/ge_graph_dsl/src/assert/ge_graph_check_dumper.cc b/tests/framework/ge_graph_dsl/src/assert/ge_graph_check_dumper.cc new file mode 100644 index 00000000..ba72cf86 --- /dev/null +++ b/tests/framework/ge_graph_dsl/src/assert/ge_graph_check_dumper.cc @@ -0,0 +1,79 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ge_graph_check_dumper.h" +#include "graph/model.h" +#include "graph/buffer.h" +#include "graph/utils/graph_utils.h" +#include "ge_graph_default_checker.h" + +GE_NS_BEGIN + +GeGraphCheckDumper::GeGraphCheckDumper() { Reset(); } + +bool GeGraphCheckDumper::IsNeedDump(const std::string &suffix) const { + auto iter = std::find(suffixes_.begin(), suffixes_.end(), suffix); + return (iter != suffixes_.end()); +} + +void GeGraphCheckDumper::Dump(const ge::ComputeGraphPtr &graph, const std::string &suffix) { + if (!IsNeedDump(suffix)) { + return; + } + auto iter = buffers_.find(suffix); + if (iter != buffers_.end()) { + DumpGraph(graph, iter->second); + } else { + buffers_[suffix] = Buffer(); + DumpGraph(graph, buffers_.at(suffix)); + } +} + +bool GeGraphCheckDumper::CheckFor(const GeGraphChecker &checker) { + auto iter = buffers_.find(checker.PhaseId()); + if (iter == buffers_.end()) { + return false; + } + DoCheck(checker, iter->second); + return true; +} + +void GeGraphCheckDumper::DoCheck(const GeGraphChecker &checker, ::GE_NS::Buffer &buffer) { + Model model("", ""); + Model::Load(buffer.GetData(), buffer.GetSize(), model); + auto load_graph = model.GetGraph(); + checker.Check(GraphUtils::GetComputeGraph(load_graph)); +} + +void GeGraphCheckDumper::DumpGraph(const ge::ComputeGraphPtr &graph, ::GE_NS::Buffer &buffer) { + Model model("", ""); + buffer.clear(); + model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph)); + model.Save(buffer, true); +} + +void GeGraphCheckDumper::Update(const std::vector &new_suffixes_) { + suffixes_ = new_suffixes_; + buffers_.clear(); +} + +void GeGraphCheckDumper::Reset() { + static std::vector default_suffixes_{"PreRunAfterBuild"}; + suffixes_ = default_suffixes_; + buffers_.clear(); +} + +GE_NS_END \ No newline at end of file diff --git a/tests/framework/ge_graph_dsl/src/assert/ge_graph_check_dumper.h b/tests/framework/ge_graph_dsl/src/assert/ge_graph_check_dumper.h new file mode 100644 index 00000000..5eda52ea --- /dev/null +++ b/tests/framework/ge_graph_dsl/src/assert/ge_graph_check_dumper.h @@ -0,0 +1,49 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef INC_8EFED0015C27464897BF64531355C810 +#define INC_8EFED0015C27464897BF64531355C810 + +#include "ge_graph_dsl/ge.h" +#include "graph/utils/dumper/ge_graph_dumper.h" +#include "ge_dump_filter.h" +#include + +GE_NS_BEGIN + +struct GeGraphChecker; + +struct GeGraphCheckDumper : GeGraphDumper, GeDumpFilter { + GeGraphCheckDumper(); + virtual void Dump(const ge::ComputeGraphPtr &graph, const std::string &suffix); + bool CheckFor(const GeGraphChecker &checker); + + private: + void DoCheck(const GeGraphChecker &checker, ::GE_NS::Buffer &buffer); + void DumpGraph(const ge::ComputeGraphPtr &graph, ::GE_NS::Buffer &buffer); + + private: + void Update(const std::vector &) override; + void Reset() override; + bool IsNeedDump(const std::string &suffix) const; + + private: + std::map buffers_; + std::vector suffixes_; +}; + +GE_NS_END + +#endif \ No newline at end of file diff --git a/tests/framework/ge_graph_dsl/src/assert/ge_graph_checker.h b/tests/framework/ge_graph_dsl/src/assert/ge_graph_checker.h new file mode 100644 index 00000000..c6b25b65 --- /dev/null +++ b/tests/framework/ge_graph_dsl/src/assert/ge_graph_checker.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef INC_5960A8F437324904BEE0690271258762 +#define INC_5960A8F437324904BEE0690271258762 + +#include "ge_graph_dsl/ge.h" +#include "easy_graph/infra/keywords.h" +#include "graph/compute_graph.h" + +GE_NS_BEGIN + +INTERFACE(GeGraphChecker) { + ABSTRACT(const std::string &PhaseId() const); + ABSTRACT(void Check(const ge::ComputeGraphPtr &graph) const); +}; + +GE_NS_END + +#endif \ No newline at end of file diff --git a/tests/framework/ge_graph_dsl/src/assert/ge_graph_default_checker.cc b/tests/framework/ge_graph_dsl/src/assert/ge_graph_default_checker.cc new file mode 100644 index 00000000..4aa48ac6 --- /dev/null +++ b/tests/framework/ge_graph_dsl/src/assert/ge_graph_default_checker.cc @@ -0,0 +1,28 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ge_graph_default_checker.h" + +GE_NS_BEGIN + +GeGraphDefaultChecker::GeGraphDefaultChecker(const std::string &phase_id, const GraphCheckFun &check_fun) + : phase_id_(phase_id), check_fun_(check_fun) {} + +const std::string &GeGraphDefaultChecker::PhaseId() const { return phase_id_; } + +void GeGraphDefaultChecker::Check(const ge::ComputeGraphPtr &graph) const { return check_fun_(graph); } + +GE_NS_END \ No newline at end of file diff --git a/tests/framework/ge_graph_dsl/src/assert/ge_graph_default_checker.h b/tests/framework/ge_graph_dsl/src/assert/ge_graph_default_checker.h new file mode 100644 index 00000000..af8f3fbe --- /dev/null +++ b/tests/framework/ge_graph_dsl/src/assert/ge_graph_default_checker.h @@ -0,0 +1,41 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef BCF4D96BE9FC48938DE7B7E93B551C54 +#define BCF4D96BE9FC48938DE7B7E93B551C54 + +#include "ge_graph_dsl/ge.h" +#include "ge_graph_checker.h" +#include "graph/compute_graph.h" + +GE_NS_BEGIN + +using GraphCheckFun = std::function; + +struct GeGraphDefaultChecker : GeGraphChecker { + GeGraphDefaultChecker(const std::string &, const GraphCheckFun &); + + private: + const std::string &PhaseId() const override; + void Check(const ge::ComputeGraphPtr &graph) const override; + + private: + const std::string phase_id_; + const GraphCheckFun check_fun_; +}; + +GE_NS_END + +#endif \ No newline at end of file diff --git a/tests/framework/ge_graph_dsl/src/op_desc_cfg_box.cc b/tests/framework/ge_graph_dsl/src/op_desc/op_desc_cfg_box.cc similarity index 100% rename from tests/framework/ge_graph_dsl/src/op_desc_cfg_box.cc rename to tests/framework/ge_graph_dsl/src/op_desc/op_desc_cfg_box.cc diff --git a/tests/framework/ge_graph_dsl/src/op_desc_cfg_repo.cc b/tests/framework/ge_graph_dsl/src/op_desc/op_desc_cfg_repo.cc similarity index 53% rename from tests/framework/ge_graph_dsl/src/op_desc_cfg_repo.cc rename to tests/framework/ge_graph_dsl/src/op_desc/op_desc_cfg_repo.cc index e7fa018f..19dfa4a5 100644 --- a/tests/framework/ge_graph_dsl/src/op_desc_cfg_repo.cc +++ b/tests/framework/ge_graph_dsl/src/op_desc/op_desc_cfg_repo.cc @@ -23,15 +23,22 @@ GE_NS_BEGIN namespace { -#define OP_CFG(optype, ...) \ - { \ - optype, OpDescCfg { \ - optype, __VA_ARGS__ \ - } \ +#define OP_CFG(optype, ...) \ + { \ + optype, OpDescCfg { optype, __VA_ARGS__ } \ } static std::map cfg_repo{OP_CFG(DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), OP_CFG(ADD, 2, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), + OP_CFG(ENTER, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), + OP_CFG(MERGE, 2, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), + OP_CFG(CONSTANT, 0, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), + OP_CFG(LESS, 2, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), + OP_CFG(LOOPCOND, 1, 1, FORMAT_NCHW, DT_BOOL, {1, 1, 224, 224}), + OP_CFG(SWITCH, 2, 2, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), + OP_CFG(EXIT, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), + OP_CFG(NEXTITERATION, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), + OP_CFG(NETOUTPUT, 2, 2, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), OP_CFG(VARIABLE, 1, 1)}; } // namespace diff --git a/tests/framework/ge_graph_dsl/src/op_desc_ptr_box.cc b/tests/framework/ge_graph_dsl/src/op_desc/op_desc_ptr_box.cc similarity index 97% rename from tests/framework/ge_graph_dsl/src/op_desc_ptr_box.cc rename to tests/framework/ge_graph_dsl/src/op_desc/op_desc_ptr_box.cc index 23d4773c..1564e019 100644 --- a/tests/framework/ge_graph_dsl/src/op_desc_ptr_box.cc +++ b/tests/framework/ge_graph_dsl/src/op_desc/op_desc_ptr_box.cc @@ -19,6 +19,4 @@ USING_GE_NS -OpDescPtr OpDescPtrBox::Build(const ::EG_NS::NodeId &id) const { - return op_; -} +OpDescPtr OpDescPtrBox::Build(const ::EG_NS::NodeId &id) const { return op_; } diff --git a/tests/framework/ge_graph_dsl/src/ge_graph_vistor.cc b/tests/framework/ge_graph_dsl/src/vistor/ge_graph_visitor.cc similarity index 89% rename from tests/framework/ge_graph_dsl/src/ge_graph_vistor.cc rename to tests/framework/ge_graph_dsl/src/vistor/ge_graph_visitor.cc index d8bc2aab..c1dca646 100644 --- a/tests/framework/ge_graph_dsl/src/ge_graph_vistor.cc +++ b/tests/framework/ge_graph_dsl/src/vistor/ge_graph_visitor.cc @@ -36,17 +36,11 @@ GE_NS_BEGIN GeGraphVisitor::GeGraphVisitor() : build_graph_(std::make_shared("")) {} -void GeGraphVisitor::reset(const ComputeGraphPtr &graph) { - build_graph_ = graph; -} +void GeGraphVisitor::reset(const ComputeGraphPtr &graph) { build_graph_ = graph; } -Graph GeGraphVisitor::BuildGeGraph() const { - return GraphUtils::CreateGraphFromComputeGraph(build_graph_); -} +Graph GeGraphVisitor::BuildGeGraph() const { return GraphUtils::CreateGraphFromComputeGraph(build_graph_); } -ComputeGraphPtr GeGraphVisitor::BuildComputeGraph() const { - return build_graph_; -} +ComputeGraphPtr GeGraphVisitor::BuildComputeGraph() const { return build_graph_; } Status GeGraphVisitor::Visit(const ::EG_NS::Graph &graph) { build_graph_->SetName(graph.GetName()); diff --git a/tests/framework/ge_graph_dsl/src/ge_subgraph_vistor.cc b/tests/framework/ge_graph_dsl/src/vistor/ge_subgraph_vistor.cc similarity index 100% rename from tests/framework/ge_graph_dsl/src/ge_subgraph_vistor.cc rename to tests/framework/ge_graph_dsl/src/vistor/ge_subgraph_vistor.cc diff --git a/tests/framework/ge_graph_dsl/src/graph_dsl.cc b/tests/framework/ge_graph_dsl/src/vistor/graph_dsl.cc similarity index 100% rename from tests/framework/ge_graph_dsl/src/graph_dsl.cc rename to tests/framework/ge_graph_dsl/src/vistor/graph_dsl.cc diff --git a/tests/framework/ge_graph_dsl/tests/CMakeLists.txt b/tests/framework/ge_graph_dsl/tests/CMakeLists.txt index 40097d8b..65482679 100644 --- a/tests/framework/ge_graph_dsl/tests/CMakeLists.txt +++ b/tests/framework/ge_graph_dsl/tests/CMakeLists.txt @@ -26,7 +26,7 @@ target_compile_options(ge_graph_dsl_test PRIVATE ) set_target_properties(ge_graph_dsl_test PROPERTIES CXX_STANDARD 17) -target_link_libraries(ge_graph_dsl_test PUBLIC gtest gtest_main ge_graph_dsl) +target_link_libraries(ge_graph_dsl_test PUBLIC gtest ge_graph_dsl) include(CTest) enable_testing() diff --git a/tests/framework/ge_graph_dsl/tests/check_graph_test.cc b/tests/framework/ge_graph_dsl/tests/check_graph_test.cc new file mode 100644 index 00000000..731b7eed --- /dev/null +++ b/tests/framework/ge_graph_dsl/tests/check_graph_test.cc @@ -0,0 +1,129 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "gtest/gtest.h" +#include "easy_graph/layout/graph_layout.h" +#include "easy_graph/layout/engines/graph_easy/graph_easy_option.h" +#include "easy_graph/layout/engines/graph_easy/graph_easy_executor.h" +#include "ge_graph_dsl/graph_dsl.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/dumper/ge_graph_dumper.h" +#include "framework/common/types.h" +#include "ge_graph_dsl/assert/graph_assert.h" +#include "graph/model.h" +#include "graph/buffer.h" + +USING_GE_NS + +class CheckGraphTest : public testing::Test { + private: + EG_NS::GraphEasyExecutor executor; + + protected: + void SetUp() { EG_NS::GraphLayout::GetInstance().Config(executor, nullptr); } + void TearDown() {} +}; + +TEST_F(CheckGraphTest, test_ge_graph_dump_is_work) { + DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; + + DUMP_GRAPH_WHEN("after_build"); + GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "after_build"); + + CHECK_GRAPH(after_build) { + ASSERT_EQ(graph->GetName(), "g1"); + ASSERT_EQ(graph->GetAllNodesSize(), 2); + }; +} + +TEST_F(CheckGraphTest, test_ge_graph_dump_two_phase) { + DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; + DEF_GRAPH(g2) { + CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); + CTRL_CHAIN(NODE("data2", DATA)->NODE("add", ADD)); + }; + + DUMP_GRAPH_WHEN("before_build", "after_build"); + + GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "before_build"); + GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g2), "after_build"); + + CHECK_GRAPH(before_build) { + ASSERT_EQ(graph->GetName(), "g1"); + ASSERT_EQ(graph->GetAllNodesSize(), 2); + }; + + CHECK_GRAPH(after_build) { + ASSERT_EQ(graph->GetName(), "g2"); + ASSERT_EQ(graph->GetAllNodesSize(), 3); + }; +} + +TEST_F(CheckGraphTest, test_ge_graph_dump_one_phase_two_times) { + DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; + DEF_GRAPH(g2) { + CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); + CTRL_CHAIN(NODE("data2", DATA)->NODE("add", ADD)); + }; + + DUMP_GRAPH_WHEN("before_build") + + GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "before_build"); + GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g2), "before_build"); + + CHECK_GRAPH(before_build) { + ASSERT_EQ(graph->GetName(), "g2"); + ASSERT_EQ(graph->GetAllNodesSize(), 3); + }; +} + +TEST_F(CheckGraphTest, test_check_phases_is_work) { + DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; + + DUMP_GRAPH_WHEN("before_build"); + GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "after_build"); + auto ret = ::GE_NS::CheckUtils::CheckGraph("after_build", [&](const ::GE_NS::ComputeGraphPtr &graph) {}); + ASSERT_FALSE(ret); +} + +TEST_F(CheckGraphTest, test_check_one_phase_dump_another_not_dump) { + DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; + + DUMP_GRAPH_WHEN("before_build"); + GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "before_build"); + GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "after_build"); + + CHECK_GRAPH(before_build) { + ASSERT_EQ(graph->GetName(), "g1"); + ASSERT_EQ(graph->GetAllNodesSize(), 2); + }; +} + +TEST_F(CheckGraphTest, test_model_serialize_and_unserialize_success) { + DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; + auto ge_graph = ToGeGraph(g1); + + ge::Model model("", ""); + model.SetGraph(ge_graph); + Buffer buffer; + model.Save(buffer, true); + + ge::Model loadModel("", ""); + Model::Load(buffer.GetData(), buffer.GetSize(), loadModel); + auto load_graph = loadModel.GetGraph(); + + ASSERT_EQ(load_graph.GetName(), "g1"); + ASSERT_EQ(load_graph.GetAllNodes().size(), 2); +} diff --git a/tests/framework/ge_graph_dsl/tests/graph_dsl_test.cc b/tests/framework/ge_graph_dsl/tests/graph_dsl_test.cc index f7e55e3d..a8240b32 100644 --- a/tests/framework/ge_graph_dsl/tests/graph_dsl_test.cc +++ b/tests/framework/ge_graph_dsl/tests/graph_dsl_test.cc @@ -37,17 +37,13 @@ class GraphDslTest : public testing::Test { EG_NS::GraphEasyExecutor executor; protected: - void SetUp() { - EG_NS::GraphLayout::GetInstance().Config(executor, nullptr); - } + void SetUp() { EG_NS::GraphLayout::GetInstance().Config(executor, nullptr); } void TearDown() {} }; TEST_F(GraphDslTest, test_build_graph_from_optype_with_name) { - DEF_GRAPH(g1) { - CHAIN(NODE("data1", DATA)->NODE("add", ADD)); - }); + DEF_GRAPH(g1) { CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; auto geGraph = ToGeGraph(g1); auto computeGraph = ToComputeGraph(g1); @@ -57,9 +53,7 @@ TEST_F(GraphDslTest, test_build_graph_from_optype_with_name) { } TEST_F(GraphDslTest, test_build_graph_with_name) { - DEF_GRAPH(g1, "sample_graph") { - CHAIN(NODE("data1", DATA)->NODE("add", ADD)); - }); + DEF_GRAPH(g1, "sample_graph") { CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; auto geGraph = ToGeGraph(g1); @@ -72,7 +66,7 @@ TEST_F(GraphDslTest, test_build_from_from_op_desc_ptr) { auto data = std::make_shared("data1", DATA); auto add = std::make_shared("Add", ADD); CHAIN(NODE(data)->NODE(add)); - }); + }; auto geGraph = ToGeGraph(g1); @@ -84,7 +78,7 @@ TEST_F(GraphDslTest, test_build_from_op_desc_cfg) { auto datCfg = OP_CFG(DATA).InCnt(1).OutCnt(1); auto addCfg = OP_CFG(DATA).InCnt(1).OutCnt(1); CHAIN(NODE("data1", datCfg)->NODE("add", addCfg)); - }); + }; auto geGraph = ToGeGraph(g1); @@ -92,9 +86,7 @@ TEST_F(GraphDslTest, test_build_from_op_desc_cfg) { } TEST_F(GraphDslTest, test_build_from_op_desc_cfg_inline) { - DEF_GRAPH(g1) { - CHAIN(NODE("data1", OP_CFG(DATA).InCnt(1).OutCnt(1))->NODE("add", OP_CFG(ADD).InCnt(2).OutCnt(1))); - }); + DEF_GRAPH(g1) { CHAIN(NODE("data1", OP_CFG(DATA).InCnt(1).OutCnt(1))->NODE("add", OP_CFG(ADD).InCnt(2).OutCnt(1))); }; auto geGraph = ToGeGraph(g1); @@ -102,9 +94,7 @@ TEST_F(GraphDslTest, test_build_from_op_desc_cfg_inline) { } TEST_F(GraphDslTest, test_build_from_control_chain) { - DEF_GRAPH(g1) { - CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); - }); + DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; auto geGraph = ToGeGraph(g1); @@ -112,9 +102,7 @@ TEST_F(GraphDslTest, test_build_from_control_chain) { } TEST_F(GraphDslTest, test_build_from_data_chain) { - DEF_GRAPH(g1) { - DATA_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); - }); + DEF_GRAPH(g1) { DATA_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; auto geGraph = ToGeGraph(g1); @@ -125,7 +113,7 @@ TEST_F(GraphDslTest, test_build_from_data_chain_with_edge) { DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); CHAIN(NODE("data1", DATA)->EDGE(2, 2)->NODE("add")); - }); + }; auto geGraph = ToGeGraph(g1); @@ -136,7 +124,7 @@ TEST_F(GraphDslTest, test_build_graph_reused_before_node) { DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); CHAIN(NODE("data1")->EDGE(2, 2)->NODE("add")); - }); + }; auto geGraph = ToGeGraph(g1); @@ -147,7 +135,7 @@ TEST_F(GraphDslTest, test_build_graph_with_constant_folding) { DEF_GRAPH(g1) { CHAIN(NODE("data1", DATA)->NODE("add", ADD)); CHAIN(NODE("data2", DATA)->NODE("add")); - }); + }; auto geGraph = ToGeGraph(g1); @@ -168,7 +156,7 @@ TEST_F(GraphDslTest, test_build_complex_normal_graph_build_suggested) { ->NODE("Add4") ->NODE("Add5") ->NODE("net_output", NETOUTPUT)); - }); + }; auto geGraph = ToGeGraph(g1); @@ -187,7 +175,7 @@ TEST_F(GraphDslTest, test_build_complex_mult_normal_graph_build) { CHAIN(NODE("add2")->NODE("net_output")); CHAIN(NODE("add3")->NODE("net_output")); CTRL_CHAIN(NODE("add1")->NODE("add2")->NODE("add3")); - }); + }; auto geGraph = ToGeGraph(g1); @@ -198,17 +186,17 @@ TEST_F(GraphDslTest, test_build_graph_with_sub_graph) { DEF_GRAPH(sub_1) { CHAIN(NODE("data_i", DATA)->NODE("less", LESS)->NODE("netoutput", NETOUTPUT)); CHAIN(NODE("const_5", CONSTANTOP)->NODE("less")); - }); + }; DEF_GRAPH(sub_2) { CHAIN(NODE("data_a", DATA)->NODE("mul", MUL)->NODE("netoutput", NETOUTPUT)); CHAIN(NODE("const_2", CONSTANTOP)->NODE("mul")); - }); + }; DEF_GRAPH(g1) { CHAIN(NODE("data_a", DATA)->NODE("while", WHILE, sub_1, sub_2)->NODE("netoutput", NETOUTPUT)); CHAIN(NODE("data_i", DATA)->NODE("while")); - }); + }; sub_1.Layout(); sub_2.Layout(); diff --git a/tests/framework/ge_graph_dsl/tests/stub/optype_stub.cc b/tests/framework/ge_graph_dsl/tests/stub/optype_stub.cc index 071f8c36..533e8198 100644 --- a/tests/framework/ge_graph_dsl/tests/stub/optype_stub.cc +++ b/tests/framework/ge_graph_dsl/tests/stub/optype_stub.cc @@ -30,5 +30,11 @@ REGISTER_OPTYPE_DEFINE(MUL, "Mul"); REGISTER_OPTYPE_DEFINE(NETOUTPUT, "NetOutput"); REGISTER_OPTYPE_DEFINE(ADD, "Add"); REGISTER_OPTYPE_DEFINE(WHILE, "While"); +REGISTER_OPTYPE_DEFINE(ENTER, "Enter"); +REGISTER_OPTYPE_DEFINE(MERGE, "Merge"); +REGISTER_OPTYPE_DEFINE(LOOPCOND, "Loopcond"); +REGISTER_OPTYPE_DEFINE(SWITCH, "Switch"); +REGISTER_OPTYPE_DEFINE(EXIT, "Exit"); +REGISTER_OPTYPE_DEFINE(NEXTITERATION, "Nextiteration"); GE_NS_END diff --git a/tests/framework/utils/builder/tensor_builder_utils.h b/tests/framework/ge_graph_dsl/tests/test_main.cc similarity index 73% rename from tests/framework/utils/builder/tensor_builder_utils.h rename to tests/framework/ge_graph_dsl/tests/test_main.cc index 73656e4a..eb6112f2 100644 --- a/tests/framework/utils/builder/tensor_builder_utils.h +++ b/tests/framework/ge_graph_dsl/tests/test_main.cc @@ -1,22 +1,25 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H -#define GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H - -class tensor_builder_utils {}; - -#endif // GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "ge_graph_dsl/assert/check_utils.h" + +int main(int argc, char **argv) { + ::GE_NS::CheckUtils::init(); + testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/tests/framework/utils/builder/graph_builder_utils.cc b/tests/framework/utils/builder/graph_builder_utils.cc deleted file mode 100644 index c5555235..00000000 --- a/tests/framework/utils/builder/graph_builder_utils.cc +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph_builder_utils.h" -#include "inc/external/graph/operator.h" -#include "inc/external/graph/operator_factory.h" -#include "graph/utils/graph_utils.h" - -namespace ge { -namespace st { -NodePtr ComputeGraphBuilder::AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, - Format format, DataType data_type, std::vector shape) { - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape(std::move(shape))); - tensor_desc->SetFormat(format); - tensor_desc->SetDataType(data_type); - - auto op_desc = std::make_shared(name, type); - for (int i = 0; i < in_cnt; ++i) { - op_desc->AddInputDesc(tensor_desc->Clone()); - } - for (int i = 0; i < out_cnt; ++i) { - op_desc->AddOutputDesc(tensor_desc->Clone()); - } - - return graph_->AddNode(op_desc); -} -void ComputeGraphBuilder::AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx) { - GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_idx), dst_node->GetInDataAnchor(dst_idx)); -} -void ComputeGraphBuilder::AddControlEdge(NodePtr &src_node, NodePtr &dst_node) { - GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor()); -} -} // namespace st -} // namespace ge diff --git a/tests/framework/utils/builder/graph_builder_utils.h b/tests/framework/utils/builder/graph_builder_utils.h deleted file mode 100644 index 4627f082..00000000 --- a/tests/framework/utils/builder/graph_builder_utils.h +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H -#define GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H - -#include -#include - -#include "graph/compute_graph.h" -#include "graph/utils/graph_utils.h" -#include "graph/graph.h" -#include "graph/node.h" - -namespace ge { -namespace st { -class ComputeGraphBuilder { - public: - explicit ComputeGraphBuilder(const std::string &name) { - graph_ = std::make_shared(name); - } - NodePtr AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, - Format format = FORMAT_NCHW, DataType data_type = DT_FLOAT, - std::vector shape = {1, 1, 224, 224}); - void AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx); - void AddControlEdge(NodePtr &src_node, NodePtr &dst_node); - ComputeGraphPtr GetComputeGraph() { - graph_->TopologicalSorting(); - return graph_; - } - Graph GetGraph() { - graph_->TopologicalSorting(); - return GraphUtils::CreateGraphFromComputeGraph(graph_); - } - - private: - ComputeGraphPtr graph_; -}; -} // namespace st -} // namespace ge - -#endif // GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H diff --git a/tests/st/testcase/CMakeLists.txt b/tests/st/testcase/CMakeLists.txt index b3663708..56b3f41b 100644 --- a/tests/st/testcase/CMakeLists.txt +++ b/tests/st/testcase/CMakeLists.txt @@ -8,7 +8,7 @@ target_include_directories(graph_engine_test set_target_properties(graph_engine_test PROPERTIES CXX_STANDARD 17) -target_link_libraries(graph_engine_test PRIVATE gtest framework) +target_link_libraries(graph_engine_test PRIVATE gtest ge_graph_dsl ge_with_env) include(CTest) enable_testing() diff --git a/tests/st/testcase/test_framework_dummy.cc b/tests/st/testcase/test_framework_dummy.cc index 0abdd18b..8f13bb78 100644 --- a/tests/st/testcase/test_framework_dummy.cc +++ b/tests/st/testcase/test_framework_dummy.cc @@ -15,23 +15,12 @@ */ #include -#include #include "external/ge/ge_api.h" -#include "ge_running_env/fake_engine.h" #include "graph/debug/ge_attr_define.h" #include "framework/common/types.h" - -#include "builder/graph_builder_utils.h" #include "ge_running_env/ge_running_env_faker.h" - -#include "graph/operator_reg.h" -#include "graph/operator.h" -#define protected public -#define private public -#include "graph/utils/op_desc_utils.h" #include "ge_graph_dsl/graph_dsl.h" -#undef protected -#undef private +#include "ge_graph_dsl/assert/graph_assert.h" using namespace std; using namespace ge; @@ -57,76 +46,58 @@ namespace { * **/ Graph BuildV1ControlFlowGraph() { - // build graph - st::ComputeGraphBuilder graphBuilder("g1"); - auto data_i = graphBuilder.AddNode("data_i", DATA, 1, 1); - auto enter_i = graphBuilder.AddNode("enter_i", ENTER, 1, 1); - ge::AttrUtils::SetStr(enter_i->GetOpDesc(), ENTER_ATTR_FRAME_NAME, "1"); - auto merge_i = graphBuilder.AddNode("merge_i", MERGE, 2, 1); - auto const_5 = graphBuilder.AddNode("const_5", CONSTANT, 0, 1); - auto less = graphBuilder.AddNode("less", LESS, 2, 1); - auto loopcond = graphBuilder.AddNode("loopcond", LOOPCOND, 1, 1, FORMAT_NCHW, DT_BOOL); - auto switch_i = graphBuilder.AddNode("switch_i", SWITCH, 2, 2); - auto exit_i = graphBuilder.AddNode("switch_i", EXIT, 1, 1); - auto const_1 = graphBuilder.AddNode("const_1", CONSTANT, 0, 1); - auto add = graphBuilder.AddNode("add", ADD, 2, 1); - auto next_iteration_i = graphBuilder.AddNode("next_iteration_i", NEXTITERATION, 1, 1); - - auto data_a = graphBuilder.AddNode("data_a", DATA, 1, 1); - auto enter_a = graphBuilder.AddNode("enter_a", ENTER, 1, 1); - ge::AttrUtils::SetStr(enter_a->GetOpDesc(), ENTER_ATTR_FRAME_NAME, "1"); - auto merge_a = graphBuilder.AddNode("merge_a", MERGE, 2, 1); - auto switch_a = graphBuilder.AddNode("switch_a", SWITCH, 2, 2); - auto exit_a = graphBuilder.AddNode("exit_a", EXIT, 1, 1); - auto mul = graphBuilder.AddNode("mul", MUL, 2, 1); - auto const_2 = graphBuilder.AddNode("const_2", CONSTANT, 0, 1); - auto next_iteration_a = graphBuilder.AddNode("next_iteration_a", NEXTITERATION, 1, 1); - auto netoutput = graphBuilder.AddNode("netoutput", NETOUTPUT, 2, 2); - // i = i+1 - graphBuilder.AddDataEdge(data_i, 0, enter_i, 0); - graphBuilder.AddDataEdge(enter_i, 0, merge_i, 0); - graphBuilder.AddDataEdge(next_iteration_i, 0, merge_i, 1); - graphBuilder.AddDataEdge(merge_i, 0, less, 0); - graphBuilder.AddDataEdge(const_5, 0, less, 1); - graphBuilder.AddDataEdge(less, 0, loopcond, 0); - graphBuilder.AddDataEdge(loopcond, 0, switch_i, 1); - graphBuilder.AddDataEdge(merge_i, 0, switch_i, 0); - graphBuilder.AddDataEdge(switch_i, 0, exit_i, 0); - graphBuilder.AddDataEdge(switch_i, 1, add, 0); - graphBuilder.AddDataEdge(const_1, 0, add, 1); - graphBuilder.AddDataEdge(add, 0, next_iteration_i, 0); - graphBuilder.AddDataEdge(exit_i, 0, netoutput, 1); - // a=a*2 - graphBuilder.AddDataEdge(data_a, 0, enter_a, 0); - graphBuilder.AddDataEdge(enter_a, 0, merge_a, 0); - graphBuilder.AddDataEdge(next_iteration_a, 0, merge_a, 1); - graphBuilder.AddDataEdge(loopcond, 0, switch_a, 1); - graphBuilder.AddDataEdge(merge_a, 0, switch_a, 0); - graphBuilder.AddDataEdge(switch_a, 0, exit_a, 0); - graphBuilder.AddDataEdge(switch_a, 1, mul, 0); - graphBuilder.AddDataEdge(const_2, 0, mul, 1); - graphBuilder.AddDataEdge(mul, 0, next_iteration_a, 0); - graphBuilder.AddDataEdge(exit_a, 0, netoutput, 0); - // set const weight int64_t dims_size = 1; vector data_vec = {5}; for_each(data_vec.begin(), data_vec.end(), [&](int64_t &data) { dims_size *= data; }); vector data_value_vec(dims_size, 1); GeTensorDesc data_tensor_desc(GeShape(data_vec), FORMAT_NCHW, DT_INT32); - GeTensorPtr data_tensor = - make_shared(data_tensor_desc, (uint8_t *)data_value_vec.data(), data_value_vec.size() * sizeof(int32_t)); - OpDescUtils::SetWeights(const_5->GetOpDesc(), data_tensor); - OpDescUtils::SetWeights(const_2->GetOpDesc(), data_tensor); - OpDescUtils::SetWeights(const_1->GetOpDesc(), data_tensor); + GeTensorPtr data_tensor = make_shared(data_tensor_desc, (uint8_t *)data_value_vec.data(), + data_value_vec.size() * sizeof(int32_t)); - return graphBuilder.GetGraph(); + auto enter = OP_CFG(ENTER).Attr(ENTER_ATTR_FRAME_NAME, "1"); + auto const_op = OP_CFG(CONSTANT).Weight(data_tensor); + + DEF_GRAPH(g1) { + CHAIN(NODE("data_i", DATA) + ->NODE("enter_i", enter) + ->EDGE(0, 0) + ->NODE("merge_i", MERGE) + ->NODE("less", LESS) + ->NODE("loopcond", LOOPCOND)); + CHAIN(NODE("const_1", const_op) + ->EDGE(0, 1) + ->NODE("add", ADD) + ->NODE("iteration_i", NEXTITERATION) + ->EDGE(0, 1) + ->NODE("merge_i")); + CHAIN(NODE("const_5", const_op)->EDGE(0, 1)->NODE("less")); + CHAIN(NODE("loopcond") + ->EDGE(0, 1) + ->NODE("switch_i", SWITCH) + ->EDGE(0, 0) + ->NODE("exit_i", EXIT) + ->EDGE(0, 1) + ->NODE("netoutput", NETOUTPUT)); + CHAIN(NODE("merge_i")->EDGE(0, 0)->NODE("switch_i")->EDGE(1, 0)->NODE("add")); + CHAIN(NODE("data_a", DATA) + ->NODE("enter_a", enter) + ->NODE("merge_a", MERGE) + ->NODE("switch_a", SWITCH) + ->NODE("exit_a", EXIT) + ->EDGE(0, 0) + ->NODE("netoutput")); + CHAIN(NODE("iteration_a", NEXTITERATION)->EDGE(0, 1)->NODE("merge_a")); + CHAIN(NODE("loopcond")->EDGE(0, 1)->NODE("switch_a")->EDGE(1, 0)->NODE("mul", MUL)); + CHAIN(NODE("const_2", const_op)->EDGE(0, 1)->NODE("mul")->EDGE(0, 0)->NODE("iteration_a")); + }; + return ToGeGraph(g1); } } // namespace class FrameworkTest : public testing::Test { protected: + GeRunningEnvFaker ge_env; void SetUp() { ge_env.InstallDefault(); } void TearDown() {} - GeRunningEnvFaker ge_env; }; /// data data @@ -136,19 +107,19 @@ TEST_F(FrameworkTest, test_framework_add) { DEF_GRAPH(g1) { CHAIN(NODE("data1", DATA)->NODE("add", ADD)); CHAIN(NODE("data2", DATA)->NODE("add")); - }); + }; - auto graph = ToGeGraph(g1); - // new session & add graph map options; Session session(options); - auto ret = session.AddGraph(1, graph, options); - EXPECT_EQ(ret, SUCCESS); - // build input tensor + session.AddGraph(1, ToGeGraph(g1), options); std::vector inputs; - // build_graph through session - ret = session.BuildGraph(1, inputs); + auto ret = session.BuildGraph(1, inputs); + EXPECT_EQ(ret, SUCCESS); + CHECK_GRAPH(PreRunAfterBuild) { + ASSERT_EQ(graph->GetName(), "g1_1"); + ASSERT_EQ(graph->GetAllNodesSize(), 4); + }; } /** data a = 2; diff --git a/tests/st/testcase/test_ge_opt_info.cc b/tests/st/testcase/test_ge_opt_info.cc index 457473b1..2e8e5382 100644 --- a/tests/st/testcase/test_ge_opt_info.cc +++ b/tests/st/testcase/test_ge_opt_info.cc @@ -15,24 +15,12 @@ */ #include -#include "easy_graph/graph/box.h" -#include "easy_graph/graph/node.h" +#include "external/ge/ge_api.h" #include "easy_graph/builder/graph_dsl.h" -#include "easy_graph/builder/box_builder.h" -#include "easy_graph/layout/graph_layout.h" -#include "easy_graph/layout/engines/graph_easy/graph_easy_option.h" -#include "easy_graph/layout/engines/graph_easy/graph_easy_executor.h" -#include "graph/graph.h" #include "graph/compute_graph.h" #include "framework/common/types.h" -#include "graph/debug/ge_attr_define.h" +#include "graph/ge_local_context.h" #include "ge_graph_dsl/graph_dsl.h" -#include "ge_graph_dsl/op_desc/op_desc_cfg_box.h" -#define protected public -#define private public -#include "ge_opt_info/ge_opt_info.h" -#undef private -#undef protected namespace ge { class STEST_opt_info : public testing::Test { @@ -52,7 +40,7 @@ TEST_F(STEST_opt_info, get_opt_info_all) { DEF_GRAPH(g1) { CHAIN(NODE("data1", DATA)->NODE("add", ADD)); CHAIN(NODE("data2", DATA)->NODE("add")); - }); + }; auto graph = ToGeGraph(g1); @@ -95,7 +83,7 @@ TEST_F(STEST_opt_info, get_opt_info_success) { DEF_GRAPH(g1) { CHAIN(NODE("data1", DATA)->NODE("add", ADD)); CHAIN(NODE("data2", DATA)->NODE("add")); - }); + }; auto graph = ToGeGraph(g1); diff --git a/tests/st/testcase/test_main.cc b/tests/st/testcase/test_main.cc index a39c68aa..a7a71954 100644 --- a/tests/st/testcase/test_main.cc +++ b/tests/st/testcase/test_main.cc @@ -15,9 +15,8 @@ */ #include - -#include "common/debug/log.h" #include "external/ge/ge_api.h" +#include "ge_graph_dsl/assert/check_utils.h" #include "ge_running_env/include/ge_running_env/ge_running_env_faker.h" using namespace std; @@ -31,6 +30,7 @@ int main(int argc, char **argv) { std::cout << "ge init failed , ret code:" << init_status << endl; } GeRunningEnvFaker::BackupEnv(); + CheckUtils::init(); testing::InitGoogleTest(&argc, argv); int ret = RUN_ALL_TESTS(); return ret; diff --git a/tests/ut/common/graph/CMakeLists.txt b/tests/ut/common/graph/CMakeLists.txt index 73780967..ccf9ce5e 100644 --- a/tests/ut/common/graph/CMakeLists.txt +++ b/tests/ut/common/graph/CMakeLists.txt @@ -90,6 +90,7 @@ set(SRC_FILES "${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" "${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" "${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" + "${GE_CODE_DIR}/metadef/graph/utils/dumper/ge_graph_dumper.cc" "${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" "${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" "${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc" diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 5b8958b4..d8fcd6c3 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -102,6 +102,7 @@ set(GRAPH_SRC_FILES "${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" "${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" "${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" + "${GE_CODE_DIR}/metadef/graph/utils/dumper/ge_graph_dumper.cc" "${GE_CODE_DIR}/metadef/graph/utils/ge_ir_utils.cc" "${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" "${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" diff --git a/tests/ut/ge/graph/build/task_generator_unittest.cc b/tests/ut/ge/graph/build/task_generator_unittest.cc index f869f1e0..1e865050 100644 --- a/tests/ut/ge/graph/build/task_generator_unittest.cc +++ b/tests/ut/ge/graph/build/task_generator_unittest.cc @@ -116,7 +116,9 @@ TEST_F(UtestTaskGeneratorTest, FindLastBpFromBpNode) { TaskGenerator task_generator(nullptr, 0); auto net_output = graph->FindNode("Node_Output"); // netoutput has no data input, return default value 0 - EXPECT_EQ(task_generator.FindLastBpFromBpNode(graph, net_output), 0); + uint32_t bp_index = 0; + EXPECT_EQ(task_generator.FindLastBpFromBpNode(graph, net_output, bp_index), 0); + EXPECT_EQ(bp_index, 2); } TEST_F(UtestTaskGeneratorTest, UpdateOpIsVarAttr) { diff --git a/tests/ut/ge/graph/passes/addn_pass_unittest.cc b/tests/ut/ge/graph/passes/addn_pass_unittest.cc index 6107a7d8..54a8819d 100644 --- a/tests/ut/ge/graph/passes/addn_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/addn_pass_unittest.cc @@ -72,7 +72,7 @@ TEST(UtestGraphPassesAddnPass, null_pass) { AddNPass *addn_pass = nullptr; NamesToPass names_to_pass; names_to_pass.emplace_back("Test", addn_pass); - EXPECT_EQ(pass.Run(names_to_pass), SUCCESS); + EXPECT_NE(pass.Run(names_to_pass), SUCCESS); } TEST(UtestGraphPassesAddnPass, null_graph) { diff --git a/tests/ut/ge/graph/passes/base_pass_unittest.cc b/tests/ut/ge/graph/passes/base_pass_unittest.cc index c687e07f..42ef6e12 100644 --- a/tests/ut/ge/graph/passes/base_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/base_pass_unittest.cc @@ -17,7 +17,6 @@ #include #include #include -#include #include #include "gtest/gtest.h" @@ -26,8 +25,6 @@ #include "graph/passes/base_pass.h" #undef protected -#include "external/graph/ge_error_codes.h" -#include "framework/common/ge_inner_error_codes.h" #include "framework/common/types.h" #include "graph/node.h" #include "graph/utils/graph_utils.h" @@ -67,6 +64,54 @@ class UtestTestPass : public BaseNodePass { names_to_add_repass_.erase(iter); } } + + iter = names_to_add_repass_immediate_.find(node->GetName()); + if (iter != names_to_add_repass_immediate_.end()) { + auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes(); + for (const auto &node_name : iter->second) { + for (auto &node_re_pass : all_nodes) { + if (node_re_pass->GetName() == node_name) { + AddImmediateRePassNode(node_re_pass); + break; + } + } + } + if (!dead_loop_) { + names_to_add_repass_immediate_.erase(iter); + } + } + + iter = names_to_add_suspend_.find(node->GetName()); + if (iter != names_to_add_suspend_.end()) { + auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes(); + for (const auto &node_name : iter->second) { + for (auto &node_re_pass : all_nodes) { + if (node_re_pass->GetName() == node_name) { + AddNodeSuspend(node_re_pass); + break; + } + } + } + if (!dead_loop_) { + names_to_add_suspend_.erase(iter); + } + } + + iter = names_to_add_resume_.find(node->GetName()); + if (iter != names_to_add_resume_.end()) { + auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes(); + for (const auto &node_name : iter->second) { + for (auto &node_re_pass : all_nodes) { + if (node_re_pass->GetName() == node_name) { + AddNodeResume(node_re_pass); + break; + } + } + } + if (!dead_loop_) { + names_to_add_resume_.erase(iter); + } + } // simulate infershape pass if(node->GetType() == WHILE){ bool need_repass = false; @@ -85,6 +130,20 @@ class UtestTestPass : public BaseNodePass { } return SUCCESS; } + + Status OnSuspendNodesLeaked() override { + // resume all node remain in suspend_nodes when leaked + auto compute_graph = (iter_nodes_.size() > 0) ? iter_nodes_[0]->GetOwnerComputeGraph() : nullptr; + if (compute_graph == nullptr) { + return SUCCESS; + } + + for (const auto &node_name : names_to_add_resume_onleaked_) { + auto node_to_resume = compute_graph->FindNode(node_name); + AddNodeResume(node_to_resume); + } + return SUCCESS; + } void clear() { iter_nodes_.clear(); } std::vector GetIterNodes() { return iter_nodes_; } @@ -94,12 +153,30 @@ class UtestTestPass : public BaseNodePass { void AddDelNodeName(const std::string &iter_node, const std::string &del_node) { names_to_add_del_[iter_node].insert(del_node); } + void AddRePassImmediateNodeName(const std::string &iter_node, const std::string &re_pass_node) { + names_to_add_repass_immediate_[iter_node].insert(re_pass_node); + } + + void AddSuspendNodeName(const std::string &iter_node, const std::string &suspend_node) { + names_to_add_suspend_[iter_node].insert(suspend_node); + } + void AddResumeNodeName(const std::string &iter_node, const std::string &resume_node) { + names_to_add_resume_[iter_node].insert(resume_node); + } + void AddResumeNodeNameOnLeaked(const std::string &resume_node) { + names_to_add_resume_onleaked_.insert(resume_node); + } + unsigned int GetRunTimes() { return run_times_; } private: std::vector iter_nodes_; std::map> names_to_add_del_; std::map> names_to_add_repass_; + std::map> names_to_add_repass_immediate_; + std::map> names_to_add_suspend_; + std::map> names_to_add_resume_; + std::unordered_set names_to_add_resume_onleaked_; bool dead_loop_; unsigned int run_times_; }; @@ -200,6 +277,26 @@ ComputeGraphPtr BuildGraph3() { return builder.GetGraph(); } +/// cast1--shape1 +/// / +/// data1 +/// \ +/// transdata1--shape2 +ComputeGraphPtr BuildGraph4() { + auto builder = ut::GraphBuilder("g1"); + auto data1 = builder.AddNode("data1", DATA, 0, 1); + auto cast1 = builder.AddNode("cast1", CAST, 1, 1); + auto shape1 = builder.AddNode("shape1", SHAPE, 1, 1); + auto transdata1 = builder.AddNode("transdata1", TRANSDATA, 1, 1); + auto shape2 = builder.AddNode("shape2", SHAPE, 1, 1); + + builder.AddDataEdge(data1, 0, cast1, 0); + builder.AddDataEdge(data1, 0, transdata1, 0); + builder.AddDataEdge(cast1, 0, shape1, 0); + builder.AddDataEdge(transdata1, 0, shape2, 0); + return builder.GetGraph(); +} + void CheckIterOrder(UtestTestPass *pass, std::vector> &nodes_layers) { std::unordered_set layer_nodes; size_t layer_index = 0; @@ -509,15 +606,369 @@ ComputeGraphPtr BuildWhileGraph1() { } TEST_F(UTESTGraphPassesBasePass, while_infershape) { -NamesToPass names_to_pass; -auto test_pass = UtestTestPass(); -names_to_pass.push_back(std::make_pair("test", &test_pass)); + NamesToPass names_to_pass; + auto test_pass = UtestTestPass(); + names_to_pass.push_back(std::make_pair("test", &test_pass)); -auto graph = BuildWhileGraph1(); -auto ge_pass = GEPass(graph); -auto while_node = graph->FindNode("while"); -EXPECT_EQ(while_node->GetOpDesc()->GetSubgraphInstanceNames().size(),1); -EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); + auto graph = BuildWhileGraph1(); + auto ge_pass = GEPass(graph); + auto while_node = graph->FindNode("while"); + EXPECT_EQ(while_node->GetOpDesc()->GetSubgraphInstanceNames().size(),1); + EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); +} + +TEST_F(UTESTGraphPassesBasePass, re_pass_pre_node_immediately) { + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + auto *test_pass = dynamic_cast(names_to_pass_[0].second); + // repass pre_node immediately + test_pass->AddRePassImmediateNodeName("reshape1", "add1"); + EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); + + EXPECT_EQ(test_pass->GetIterNodes().size(), 9);// todo + std::vector> layers; + layers.push_back({"data1", "const1", "const2"}); + layers.push_back({"shape1"}); + layers.push_back({"add1", "addn1"}); + layers.push_back({"reshape1", "add1", "sum1"}); + CheckIterOrder(test_pass, layers); +} + +TEST_F(UTESTGraphPassesBasePass, re_pass_cur_node_immediately) { + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + auto *test_pass = dynamic_cast(names_to_pass_[0].second); + // repass cur_node immediately + test_pass->AddRePassImmediateNodeName("reshape1", "reshape1"); + EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); + + EXPECT_EQ(test_pass->GetIterNodes().size(), 9); + std::vector> layers; + layers.push_back({"data1", "const1", "const2"}); + layers.push_back({"shape1"}); + layers.push_back({"add1", "addn1"}); + layers.push_back({"reshape1"}); + layers.push_back({"reshape1", "sum1"}); + CheckIterOrder(test_pass, layers); +} + +TEST_F(UTESTGraphPassesBasePass, re_pass_next_node_immediately) { + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + auto *test_pass = dynamic_cast(names_to_pass_[0].second); + // repass next_node immediately + test_pass->AddRePassImmediateNodeName("reshape1", "sum1"); + // repass node after next_node immediately + test_pass->AddRePassImmediateNodeName("add1", "sum1"); + EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); + + EXPECT_EQ(test_pass->GetIterNodes().size(), 8); + std::vector> layers; + layers.push_back({"data1", "const1", "const2"}); + layers.push_back({"shape1"}); + layers.push_back({"add1", "addn1"}); + layers.push_back({"reshape1", "sum1"}); + CheckIterOrder(test_pass, layers); +} +/** + * A->B->C + * if node B suspend its pre_node A, and C resume A, it is a useless operation, so iter_order should follow normal order + * when C resuem A, A will pass again. + */ +TEST_F(UTESTGraphPassesBasePass, B_suspend_pre_node_A_then_C_resume_A) { + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + auto *test_pass = dynamic_cast(names_to_pass_[0].second); + // add1->reshape1->sum1 + test_pass->AddSuspendNodeName("reshape1", "add1"); + test_pass->AddResumeNodeName("sum1", "add1"); + EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); + EXPECT_EQ(test_pass->GetIterNodes().size(), 9); + std::vector> layers; + layers.push_back({"data1", "const1", "const2"}); + layers.push_back({"shape1"}); + layers.push_back({"add1", "addn1"}); + layers.push_back({"reshape1", "sum1"}); + layers.push_back({"add1"}); + CheckIterOrder(test_pass, layers); } +/** + * A->B->C + * if node B suspend its pre_node A, and B resume A, it is a useless operation, so iter_order should follow normal order + * when B resuem A, A will pass again. + */ +TEST_F(UTESTGraphPassesBasePass, B_suspend_pre_node_A_then_B_resume_A) { + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + auto *test_pass = dynamic_cast(names_to_pass_[0].second); + // add1->reshape1->sum1 + test_pass->AddSuspendNodeName("reshape1", "add1"); + test_pass->AddResumeNodeName("reshape1", "add1"); + EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); + EXPECT_EQ(test_pass->GetIterNodes().size(), 9); + std::vector> layers; + layers.push_back({"data1", "const1", "const2"}); + layers.push_back({"shape1"}); + layers.push_back({"add1", "addn1"}); + layers.push_back({"reshape1", "sum1", "add1"}); + CheckIterOrder(test_pass, layers); +} + +/** + * A->B->C + * if node B resume C(which is not suspended), it is a useless operation, C will not pass. + */ +TEST_F(UTESTGraphPassesBasePass, B_resume_node_not_suspended) { + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + auto *test_pass = dynamic_cast(names_to_pass_[0].second); + // add1->reshape1->sum1 + test_pass->AddResumeNodeName("reshape1", "sum1"); + EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); + EXPECT_EQ(test_pass->GetIterNodes().size(), 8); + std::vector> layers; + layers.push_back({"data1", "const1", "const2"}); + layers.push_back({"shape1"}); + layers.push_back({"add1", "addn1"}); + layers.push_back({"reshape1", "sum1"}); + CheckIterOrder(test_pass, layers); +} + +/** + * A->B->C + * if node B suspend its pre_node A, it is a useless operation, so iter_order should follow normal order + * because nobody resume it ,which means A is a leaked node, so return fail + */ +TEST_F(UTESTGraphPassesBasePass, suspend_pre_node_nobody_resume_it_return_failed) { + NamesToPass names_to_pass; + auto test_pass = UtestTestPass(); + names_to_pass.push_back(std::make_pair("test", &test_pass)); + // suspend pre_node immediately + test_pass.AddSuspendNodeName("reshape1", "add1"); + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + EXPECT_EQ(ge_pass.Run(names_to_pass), INTERNAL_ERROR); +} + +/** + * A->B->C + * if node B suspend its pre_node A, it is a useless operation, + * so iter_order should follow normal order + * resume A on leaked, which means A will pass again + */ +TEST_F(UTESTGraphPassesBasePass, suspend_pre_node_resume_it_onleaked) { + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + auto *test_pass = dynamic_cast(names_to_pass_[0].second); + // suspend pre_node immediately + test_pass->AddSuspendNodeName("reshape1", "add1"); + test_pass->AddResumeNodeNameOnLeaked("add1"); + EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); + std::vector> layers; + layers.push_back({"data1", "const1", "const2"}); + layers.push_back({"shape1"}); + layers.push_back({"add1", "addn1"}); + layers.push_back({"reshape1", "sum1"}); + layers.push_back({"add1"}); + CheckIterOrder(test_pass, layers); +} + + +/// cast1--shape1 +/// / +/// data1 +/// \ +/// transdata1--shape2 +/** + * suspend cur node + * cast1 suspend itself, shape2 resume cast1 + * iter order follows : data1; cast1,transdata1; shape2; cast1 ; shape1 + */ +TEST_F(UTESTGraphPassesBasePass, cast1_suspend_cur_node_shape2_resume_cast1) { + auto graph = BuildGraph4(); + auto ge_pass = GEPass(graph); + auto *test_pass = dynamic_cast(names_to_pass_[0].second); + // suspend pre_node immediately + test_pass->AddSuspendNodeName("cast1", "cast1"); + test_pass->AddResumeNodeName("shape2", "cast1"); + EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); + EXPECT_EQ(test_pass->GetIterNodes().size(), 6); + std::vector> layers; + layers.push_back({"data1"}); + layers.push_back({"cast1","transdata1"}); + layers.push_back({"shape2"}); + layers.push_back({"cast1", "shape1"}); + CheckIterOrder(test_pass, layers); +} +/** + * suspend cur node + * cast1 suspend itself, then resume cast1 + * iter order follows : data1; cast1,cast1,transdata1; shape2; shape1. + */ +TEST_F(UTESTGraphPassesBasePass, cast1_suspend_itslef_then_resume_itself) { + auto graph = BuildGraph4(); + auto ge_pass = GEPass(graph); + auto *test_pass = dynamic_cast(names_to_pass_[0].second); + // suspend pre_node immediately + test_pass->AddSuspendNodeName("cast1", "cast1"); + test_pass->AddResumeNodeName("cast1", "cast1"); + EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); + EXPECT_EQ(test_pass->GetIterNodes().size(), 6); + std::vector> layers; + layers.push_back({"data1"}); + layers.push_back({"cast1","transdata1","cast1","shape1", "shape2"}); + CheckIterOrder(test_pass, layers); +} +/** + * suspend cur node + * cast1 suspend itself, then resume cast1 on leaked + * iter order follows : data1; cast1,cast1,transdata1; shape2; shape1. + */ +TEST_F(UTESTGraphPassesBasePass, cast1_suspend_itslef_then_resume_onleaked) { + auto graph = BuildGraph4(); + auto ge_pass = GEPass(graph); + auto *test_pass = dynamic_cast(names_to_pass_[0].second); + // suspend pre_node immediately + test_pass->AddSuspendNodeName("cast1", "cast1"); + test_pass->AddResumeNodeNameOnLeaked("cast1"); + EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); + EXPECT_EQ(test_pass->GetIterNodes().size(), 6); + std::vector> layers; + layers.push_back({"data1"}); + layers.push_back({"cast1","transdata1", "shape2"}); + layers.push_back({"cast1","shape1"}); + CheckIterOrder(test_pass, layers); +} +/** + * suspend next node + * data1 suspend cast1, then resume cast1 on leaked + * iter order follows : data1; transdata1, shape2; cast1, shape1. + */ +TEST_F(UTESTGraphPassesBasePass, data1_suspend_cast1_resume_cast1_onleaked) { + auto graph = BuildGraph4(); + auto ge_pass = GEPass(graph); + auto *test_pass = dynamic_cast(names_to_pass_[0].second); + // suspend pre_node immediately + test_pass->AddSuspendNodeName("data1", "cast1"); + test_pass->AddResumeNodeNameOnLeaked("cast1"); + EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); + EXPECT_EQ(test_pass->GetIterNodes().size(), 5); + std::vector> layers; + layers.push_back({"data1"}); + layers.push_back({"transdata1", "shape2"}); + layers.push_back({"cast1","shape1"}); + CheckIterOrder(test_pass, layers); +} + +/** + * suspend next node + * data1 suspend cast1, nobody resume it + * iter order follows : data1; transdata1, shape2; + * run ret is failed ,because node leaked + */ +TEST_F(UTESTGraphPassesBasePass, data1_suspend_cast1_nobody_resume) { + auto graph = BuildGraph4(); + auto ge_pass = GEPass(graph); + auto *test_pass = dynamic_cast(names_to_pass_[0].second); + // suspend pre_node immediately + test_pass->AddSuspendNodeName("data1", "cast1"); + EXPECT_EQ(ge_pass.Run(names_to_pass_), INTERNAL_ERROR); + EXPECT_EQ(test_pass->GetIterNodes().size(), 3); +} + + +TEST_F(UTESTGraphPassesBasePass, re_pass_pre_node_immediately) { + NamesToPass names_to_pass; + auto test_pass = UtestTestPass(); + names_to_pass.push_back(std::make_pair("test", &test_pass)); + + // repass pre_node immediately + test_pass.AddRePassImmediateNodeName("reshape1", "add1"); + + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); + EXPECT_EQ(test_pass.GetIterNodes().size(), 9);// todo + std::vector> layers; + layers.push_back({"data1", "const1", "const2"}); + layers.push_back({"shape1"}); + layers.push_back({"add1", "addn1"}); + layers.push_back({"reshape1", "add1", "sum1"}); + CheckIterOrder(&test_pass, layers); +} +/// sum1 +/// / \. +/// / \. +/// / \. +/// reshape1 addn1 +/// | c | +/// add1 <--- shape1 +/// / \ | +/// | | | +/// data1 const1 const2 +TEST_F(UTESTGraphPassesBasePass, re_pass_cur_node_immediately) { + NamesToPass names_to_pass; + auto test_pass = UtestTestPass(); + names_to_pass.push_back(std::make_pair("test", &test_pass)); + + // repass cur_node immediately + test_pass.AddRePassImmediateNodeName("reshape1", "reshape1"); + + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); + EXPECT_EQ(test_pass.GetIterNodes().size(), 9);// todo + std::vector> layers; + layers.push_back({"data1", "const1", "const2"}); + layers.push_back({"shape1"}); + layers.push_back({"add1", "addn1"}); + layers.push_back({"reshape1"}); + layers.push_back({"reshape1", "sum1"}); + CheckIterOrder(&test_pass, layers); +} + +TEST_F(UTESTGraphPassesBasePass, re_pass_next_node_immediately) { + NamesToPass names_to_pass; + auto test_pass = UtestTestPass(); + names_to_pass.push_back(std::make_pair("test", &test_pass)); + + // repass next_node immediately + test_pass.AddRePassImmediateNodeName("reshape1", "sum1"); + // repass node after next_node immediately + test_pass.AddRePassImmediateNodeName("add1", "sum1"); + + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); + EXPECT_EQ(test_pass.GetIterNodes().size(), 8);// todo + std::vector> layers; + layers.push_back({"data1", "const1", "const2"}); + layers.push_back({"shape1"}); + layers.push_back({"add1", "addn1"}); + layers.push_back({"reshape1", "sum1"}); + CheckIterOrder(&test_pass, layers); +} +/* +TEST_F(UTESTGraphPassesBasePass, suspend_pre_node) { + NamesToPass names_to_pass; + auto test_pass = UtestTestPass(); + names_to_pass.push_back(std::make_pair("test", &test_pass)); + + // repass next_node immediately + test_pass.AddRePassNodeName("reshape1", "sum1"); + // repass node after next_node immediately + test_pass.AddRePassNodeName("add1", "sum1"); + + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); + EXPECT_EQ(test_pass.GetIterNodes().size(), 8);// todo + std::vector> layers; + layers.push_back({"data1", "const1", "const2"}); + layers.push_back({"shape1"}); + layers.push_back({"add1", "addn1"}); + layers.push_back({"reshape1", "sum1"}); + CheckIterOrder(&test_pass, layers); +}*/ } // namespace ge diff --git a/tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc b/tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc index 576d679c..c39755b3 100644 --- a/tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc @@ -293,6 +293,9 @@ class AddKernel : public Kernel { } else if (input[0]->GetTensorDesc().GetDataType() == DT_INT32 || input[0]->GetTensorDesc().GetDataType() == DT_UINT32) { vector data_vec; auto data_num = input[0]->GetTensorDesc().GetShape().GetShapeSize(); + if (input[0]->GetTensorDesc().GetShape().IsScalar()) { + data_num = 1; + } auto x1_data = reinterpret_cast(input[0]->GetData().data()); auto x2_data = reinterpret_cast(input[1]->GetData().data()); for (size_t i = 0; i < data_num; i++) { @@ -410,6 +413,48 @@ TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsHave EXPECT_EQ(unknown_target_value_range, output_value_range); } +TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsAreKnownValueRange_ScalarOutput) { + // shape --- add --- sqrt + // constant / + auto graph = std::make_shared("test_graph"); + vector data_vec = {2}; + GeTensorDesc const_td(ge::GeShape(), ge::FORMAT_NCHW, ge::DT_INT32); + GeTensorPtr const_tensor = std::make_shared(const_td, (uint8_t *)data_vec.data(), sizeof(int32_t)); + auto const_op_desc = std::make_shared("Constant", "Constant"); + const_op_desc->AddOutputDesc(const_td); + EXPECT_EQ(OpDescUtils::SetWeights(const_op_desc, const_tensor), GRAPH_SUCCESS); + auto const_node = graph->AddNode(const_op_desc); + + GeTensorDesc shape_td(GeShape(), ge::FORMAT_NCHW, ge::DT_INT32); + std::vector> known_value_range = {make_pair(1, 100)}; + shape_td.SetValueRange(known_value_range); + auto shape_op_desc = std::make_shared("Shape", "Shape"); + shape_op_desc->AddOutputDesc(shape_td); + auto shape_node = graph->AddNode(shape_op_desc); + + GeTensorDesc add_td(GeShape(), ge::FORMAT_NCHW, ge::DT_INT32); + auto add_op_desc = std::make_shared("Add", "Add"); + add_op_desc->AddInputDesc(shape_td); + add_op_desc->AddInputDesc(const_td); + add_op_desc->AddOutputDesc(add_td); + auto add_node = graph->AddNode(add_op_desc); + + ge::GraphUtils::AddEdge(shape_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(0)); + ge::GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(1)); + + InferValueRangePass infer_pass; + EXPECT_EQ(infer_pass.Run(add_node), SUCCESS); + + auto output_0_desc = add_node->GetOpDesc()->GetOutputDesc(0); + std::vector> out_value_range; + output_0_desc.GetValueRange(out_value_range); + EXPECT_EQ(out_value_range.size(), 1); + + std::vector target_value_range = {3, 102}; + std::vector output_value_range = {out_value_range[0].first, out_value_range[0].second}; + EXPECT_EQ(output_value_range, target_value_range); +} + TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsAreKnownValueRange_Int64) { // shape --- add --- sqrt // constant / diff --git a/tests/ut/ge/graph/passes/infershape_pass_unittest.cc b/tests/ut/ge/graph/passes/infershape_pass_unittest.cc index 13e66c50..080894ee 100644 --- a/tests/ut/ge/graph/passes/infershape_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/infershape_pass_unittest.cc @@ -29,13 +29,77 @@ using namespace std; using namespace testing; namespace ge { +namespace { +// do nothing stub infer_func +const auto stub_func = [](Operator &op) { return GRAPH_SUCCESS; }; +// infer from input to output stub infer_func (input size == output size) +const auto stub_mapping_func = [](Operator &op) { + size_t in_num = op.GetInputsSize(); + for (size_t i = 0; i < in_num; ++i) { + auto in_desc = op.GetInputDesc(i); + auto out_desc = op.GetOutputDesc(i); + out_desc.SetShape(in_desc.GetShape()); + out_desc.SetDataType(in_desc.GetDataType()); + op.UpdateOutputDesc(out_desc.GetName(), out_desc); + } + return GRAPH_SUCCESS; +}; +// merge infer_func + +// while infer_func +const auto while_infer_func = [](Operator &op) { + size_t in_num = op.GetInputsSize(); + size_t out_num = op.GetOutputsSize(); + if (in_num != out_num) { + return GRAPH_FAILED; + } + bool need_infer_again = false; + for (size_t i = 0; i < in_num; ++i) { + auto in_desc = op.GetDynamicInputDesc("input", i); + auto out_desc = op.GetDynamicOutputDesc("output", i); + auto data_shape = in_desc.GetShape(); + auto out_shape = out_desc.GetShape(); + if(out_shape.GetDims() == DUMMY_SHAPE){ + return GRAPH_SUCCESS; + } + // check datatype between output and input + if (in_desc.GetDataType() != out_desc.GetDataType()) { + return GRAPH_FAILED; + } + + if (data_shape.GetDims() != out_shape.GetDims()) { + need_infer_again = true; + if (data_shape.GetDimNum() != out_shape.GetDimNum()) { + in_desc.SetUnknownDimNumShape(); + } else { + size_t data_dim_num = data_shape.GetDimNum(); + std::vector> data_shape_range = {data_dim_num, std::make_pair(1, UNKNOWN_DIM)}; + for (size_t j = 0; j < data_dim_num; ++j) { + if (data_shape.GetDim(j) != out_shape.GetDim(j)) { + data_shape.SetDim(j, UNKNOWN_DIM); + } + if (data_shape.GetDim(j) != UNKNOWN_DIM) { + data_shape_range[j] = std::make_pair(data_shape.GetDim(j), data_shape.GetDim(j)); + } + } + in_desc.SetShape(data_shape); + in_desc.SetShapeRange(data_shape_range); + } + op.UpdateDynamicOutputDesc("output", i, in_desc); + op.UpdateDynamicInputDesc("input", i, in_desc); + } + } + return need_infer_again ? GRAPH_NODE_NEED_REPASS : GRAPH_SUCCESS; +}; +} class UtestGraphInfershapePass : public testing::Test { protected: void SetUp() {} void TearDown() {} }; -static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) { +static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num, + std::function infer_func = stub_func) { OpDescPtr op_desc = std::make_shared(name, type); op_desc->SetStreamId(0); static int32_t index = 0; @@ -61,14 +125,11 @@ static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string op_desc->SetWorkspaceBytes({}); op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE"); - const auto stub_func = [](Operator &op) { return GRAPH_SUCCESS; }; - op_desc->AddInferFunc(stub_func); - op_desc->AddInferFormatFunc(stub_func); - op_desc->AddVerifierFunc(stub_func); - + op_desc->AddInferFunc(infer_func); return graph.AddNode(op_desc); } +/* TEST_F(UtestGraphInfershapePass, infershape_pass_failed) { GeTensorDesc ge_tensor_desc(GeShape({-2, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16); string type = "AddN"; @@ -82,6 +143,7 @@ TEST_F(UtestGraphInfershapePass, infershape_pass_failed) { InferShapePass infershape_pass; EXPECT_EQ(infershape_pass.Run(addn_node), GE_GRAPH_INFERSHAPE_FAILED); } +*/ TEST_F(UtestGraphInfershapePass, delete_need_infer_again) { auto graph = std::make_shared("test"); @@ -94,7 +156,43 @@ TEST_F(UtestGraphInfershapePass, delete_need_infer_again) { infershape_pass.options_[kOptimizeAfterSubGraph] = "yes"; EXPECT_EQ(infershape_pass.Run(no_op_node), SUCCESS); } +TEST_F(UtestGraphInfershapePass, infer_from_pre_to_next) { + /* + * cast->shape + */ + auto graph = std::make_shared("test_infer_shape"); + auto data1 = CreateNode(*graph, "dataq", DATA, 0, 1); + auto cast1 = CreateNode(*graph, "cast1", CAST, 1, 1, stub_mapping_func); + auto cast_in_desc = cast1->GetOpDesc()->MutableInputDesc(0); + cast_in_desc->SetShape(GeShape({1,2,3})); + cast_in_desc->SetDataType(DT_INT32); + auto transdata1 = CreateNode(*graph, "transdata1", TRANSDATA, 1, 1, stub_mapping_func); + GraphUtils::AddEdge(data1->GetOutDataAnchor(0), cast1->GetInDataAnchor(0)); + GraphUtils::AddEdge(cast1->GetOutDataAnchor(0), transdata1->GetInDataAnchor(0)); + // check before infer cast1 + auto cast_before = graph->FindNode("cast1"); + vector expect_cast1_shape_dim = {1,2,3}; + auto real_cast1_before_shape_dim = cast_before->GetOpDesc()->GetInputDesc(0).GetShape().GetDims(); + auto transdata1_before = graph->FindNode("transdata1"); + vector expect_transdata1_shape_dim = {}; + auto real_transdata1_before_shape_dim = transdata1_before->GetOpDesc()->GetInputDesc(0).GetShape().GetDims(); + EXPECT_EQ(real_cast1_before_shape_dim, expect_cast1_shape_dim); + EXPECT_EQ(real_transdata1_before_shape_dim, expect_transdata1_shape_dim); + // run infershape pass + InferShapePass infer_shape_pass; + infer_shape_pass.Run(cast_before); + // check cast1 add transdata1 to repass_immediately + infer_shape_pass.GetNodesNeedRePassImmediately(); + EXPECT_TRUE(!infer_shape_pass.GetNodesNeedRePassImmediately().empty()); + // check transdata input_shape & datatype after infer + auto transdata1_after = graph->FindNode("transdata1"); + auto transdata1_opdesc = transdata1_before->GetOpDesc(); + auto real_transdata1_after_shape_dim = transdata1_opdesc->GetInputDesc(0).GetShape().GetDims(); + EXPECT_EQ(real_transdata1_after_shape_dim, expect_cast1_shape_dim); + auto transdata1_datatype_after = transdata1_opdesc->GetInputDesc(0).GetDataType(); + EXPECT_EQ(transdata1_datatype_after, DT_INT32); +} TEST_F(UtestGraphInfershapePass, stop_node_for_while_loop) { /******************************************************************************* * Exit Identify diff --git a/tests/ut/ge/hybrid/executor/hybrid_model_async_executor_unittest.cc b/tests/ut/ge/hybrid/executor/hybrid_model_async_executor_unittest.cc index f772af23..c053885f 100644 --- a/tests/ut/ge/hybrid/executor/hybrid_model_async_executor_unittest.cc +++ b/tests/ut/ge/hybrid/executor/hybrid_model_async_executor_unittest.cc @@ -103,4 +103,32 @@ TEST_F(UtestHybridModelAsyncExecutor, Test_execute) { context.callback_manager->callback_queue_.Push(eof_entry); ASSERT_EQ(executor.Execute(args), SUCCESS); } + +TEST_F(UtestHybridModelAsyncExecutor, test_PrepareInputs) { + ComputeGraphPtr graph = std::make_shared("test"); + GeRootModelPtr ge_root_model = make_shared(graph); + ge_root_model->SetModelName("test_name"); + GeModelPtr ge_sub_model = make_shared(); + HybridModel hybrid_model(ge_root_model); + HybridModelAsyncExecutor executor(&hybrid_model); + GeTensorDescPtr tensor_desc = make_shared(GeShape({-1, 16, 16, 3})); + tensor_desc->SetShapeRange({{1, 256}, {16, 16}, {16, 16}, {3, 3}}); + executor.input_tensor_desc_.insert({0, tensor_desc}); + executor.device_id_ = 0; + executor.input_sizes_.insert({0, -1}); + executor.is_input_dynamic_.push_back(true); + + unique_ptr data_buf(new (std::nothrow)uint8_t[3072]); + InputData input_data; + input_data.blobs.push_back(DataBuffer(data_buf.get(), 3072, false)); + input_data.shapes.push_back({1, 16, 16, 3}); + HybridModelExecutor::ExecuteArgs args; + + auto ret = executor.PrepareInputs(input_data, args); + ASSERT_EQ(ret, SUCCESS); + ASSERT_EQ(args.input_desc[0]->GetShape().ToString(), GeShape({1, 16, 16, 3}).ToString()); + int64_t tensor_size = 0; + TensorUtils::GetSize(*(args.input_desc[0]), tensor_size); + ASSERT_EQ(tensor_size, 3104); +} } // namespace ge \ No newline at end of file diff --git a/tests/ut/ge/hybrid/executor/worker/execution_engine_unittest.cc b/tests/ut/ge/hybrid/executor/worker/execution_engine_unittest.cc index 07701f4d..96641c59 100644 --- a/tests/ut/ge/hybrid/executor/worker/execution_engine_unittest.cc +++ b/tests/ut/ge/hybrid/executor/worker/execution_engine_unittest.cc @@ -27,6 +27,7 @@ #include "hybrid/executor/hybrid_model_executor.h" #include "hybrid/executor/worker/execution_engine.h" #include "hybrid/executor/subgraph_executor.h" +#include "hybrid/executor/worker/task_compile_engine.h" #undef private #undef protected @@ -45,7 +46,14 @@ class UtestExecutionEngine : public testing::Test { }; namespace { const int kIntBase = 10; +class CompileNodeExecutor : public NodeExecutor { + public: + Status CompileTask(const HybridModel &model, const NodePtr &node, std::shared_ptr &task) const override { + return SUCCESS; + } +}; } + static ge::OpDescPtr CreateOpDesc(string name = "", string type = "") { auto op_desc = std::make_shared(name, type); op_desc->SetStreamId(0); @@ -128,4 +136,8 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) { executor.InitCallback(node_state.get(), callback); ExecutionEngine execution_engine; EXPECT_EQ(execution_engine.ExecuteAsync(*node_state, node_state->GetTaskContext(), execution_context, callback), INTERNAL_ERROR); + + CompileNodeExecutor node_executor; + node_item->node_executor = &node_executor; + EXPECT_EQ(TaskCompileEngine::Compile(*node_state, &execution_context), SUCCESS); } diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index 1d1c4fa9..d1c51c67 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -153,6 +153,7 @@ TEST_F(UtestGeHybrid, task_update_tiling_info) { ge::AttrUtils::SetStr(op_desc, "compile_info_json", "json"); ge::AttrUtils::SetBool(op_desc, "support_dynamicshape", true); ge::AttrUtils::SetInt(op_desc, "op_para_size", 1); + ge::AttrUtils::SetStr(op_desc, TVM_ATTR_NAME_MAGIC, "RT_DEV_BINARY_MAGIC_ELF"); auto node = graph->AddNode(op_desc); std::unique_ptr node_item; diff --git a/tests/ut/ge/hybrid/node_executor/node_executor_unittest.cc b/tests/ut/ge/hybrid/node_executor/node_executor_unittest.cc index 8a1240d3..1d5bbb3d 100644 --- a/tests/ut/ge/hybrid/node_executor/node_executor_unittest.cc +++ b/tests/ut/ge/hybrid/node_executor/node_executor_unittest.cc @@ -87,6 +87,7 @@ TEST_F(NodeExecutorTest, TestGetOrCreateExecutor) { TEST_F(NodeExecutorTest, TestInitAndFinalize) { auto &manager = NodeExecutorManager::GetInstance(); manager.FinalizeExecutors(); + manager.FinalizeExecutors(); manager.EnsureInitialized(); manager.EnsureInitialized(); const NodeExecutor *executor = nullptr; diff --git a/tests/ut/ge/single_op/single_op_model_unittest.cc b/tests/ut/ge/single_op/single_op_model_unittest.cc index e4a53340..23269814 100644 --- a/tests/ut/ge/single_op/single_op_model_unittest.cc +++ b/tests/ut/ge/single_op/single_op_model_unittest.cc @@ -311,7 +311,7 @@ TEST_F(UtestSingleOpModel, BuildTaskList) { ASSERT_EQ(mem_task.LaunchKernel(0), SUCCESS); } -TEST_F(UtestSingleOpModel, build_aicpu_task) { +TEST_F(UtestSingleOpModel, build_dynamic_task) { ComputeGraphPtr graph = make_shared("single_op"); GeModelPtr ge_model = make_shared(); ge_model->SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph)); @@ -321,6 +321,15 @@ TEST_F(UtestSingleOpModel, build_aicpu_task) { domi::TaskDef *task_def = model_task_def->add_task(); task_def->set_type(RT_MODEL_TASK_KERNEL_EX); + domi::TaskDef *task_def2 = model_task_def->add_task(); + task_def2->set_type(RT_MODEL_TASK_KERNEL); + domi::KernelDef *kernel_def = task_def2->mutable_kernel(); + domi::KernelContext *context = kernel_def->mutable_context(); + context->set_kernel_type(6); // ccKernelType::AI_CPU + + domi::TaskDef *task_def3 = model_task_def->add_task(); + task_def3->set_type(RT_MODEL_TASK_ALL_KERNEL); + string model_data_str = "123456789"; SingleOpModel model("model", model_data_str.c_str(), model_data_str.size()); std::mutex stream_mu; @@ -329,8 +338,18 @@ TEST_F(UtestSingleOpModel, build_aicpu_task) { DynamicSingleOp single_op(0, &stream_mu, stream); model.model_helper_.model_ = ge_model; auto op_desc = std::make_shared("add", "Add"); + AttrUtils::SetStr(op_desc, TVM_ATTR_NAME_MAGIC, "RT_DEV_BINARY_MAGIC_ELF"); + std::vector kernelBin; + TBEKernelPtr tbe_kernel = std::make_shared("name/Add", std::move(kernelBin)); + op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel); NodePtr node = graph->AddNode(op_desc); model.op_list_[0] = node; StreamResource *res = new (std::nothrow) StreamResource(1); + + ASSERT_EQ(model.ParseTasks(), SUCCESS); + ASSERT_EQ(model.BuildTaskListForDynamicOp(res, single_op), SUCCESS); + model.tbe_tasks_.clear(); ASSERT_EQ(model.BuildTaskListForDynamicOp(res, single_op), SUCCESS); + model.aicpu_tasks_[0] = *task_def2; + model.BuildTaskListForDynamicOp(res, single_op); } diff --git a/tests/ut/ge/single_op/single_op_task_unittest.cc b/tests/ut/ge/single_op/single_op_task_unittest.cc index b0c98205..2424d209 100644 --- a/tests/ut/ge/single_op/single_op_task_unittest.cc +++ b/tests/ut/ge/single_op/single_op_task_unittest.cc @@ -54,6 +54,7 @@ TEST_F(UtestSingleOpTask, test_build_kernel_task) { auto graph = make_shared("graph"); auto op_desc = make_shared("Add", "Add"); + AttrUtils::SetStr(op_desc, TVM_ATTR_NAME_MAGIC, "RT_DEV_BINARY_MAGIC_ELF"); std::vector kernelBin; TBEKernelPtr tbe_kernel = std::make_shared("name/Add", std::move(kernelBin)); op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel); diff --git a/third_party/fwkacllib/inc/external/runtime/rt_error_codes.h b/third_party/fwkacllib/inc/external/runtime/rt_error_codes.h index 67146dbe..9f216a56 100755 --- a/third_party/fwkacllib/inc/external/runtime/rt_error_codes.h +++ b/third_party/fwkacllib/inc/external/runtime/rt_error_codes.h @@ -38,6 +38,7 @@ static const int32_t ACL_ERROR_RT_STREAM_NO_CB_REG = 107015; // callba static const int32_t ACL_ERROR_RT_INVALID_MEMORY_TYPE = 107016; // invalid memory type static const int32_t ACL_ERROR_RT_INVALID_HANDLE = 107017; // invalid handle static const int32_t ACL_ERROR_RT_INVALID_MALLOC_TYPE = 107018; // invalid malloc type +static const int32_t ACL_ERROR_RT_WAIT_TIMEOUT = 107019; // wait timeout static const int32_t ACL_ERROR_RT_FEATURE_NOT_SUPPORT = 207000; // feature not support static const int32_t ACL_ERROR_RT_MEMORY_ALLOCATION = 207001; // memory allocation error @@ -50,6 +51,7 @@ static const int32_t ACL_ERROR_RT_NO_EVENT_RESOURCE = 207007; // no eve static const int32_t ACL_ERROR_RT_NO_STREAM_RESOURCE = 207008; // no stream resource static const int32_t ACL_ERROR_RT_NO_NOTIFY_RESOURCE = 207009; // no notify resource static const int32_t ACL_ERROR_RT_NO_MODEL_RESOURCE = 207010; // no model resource +static const int32_t ACL_ERROR_RT_NO_CDQ_RESOURCE = 207011; // no cdq resource static const int32_t ACL_ERROR_RT_INTERNAL_ERROR = 507000; // runtime internal error static const int32_t ACL_ERROR_RT_TS_ERROR = 507001; // ts internel error @@ -85,9 +87,14 @@ static const int32_t ACL_ERROR_RT_DEBUG_UNREGISTER_FAIL = 507030; // debug static const int32_t ACL_ERROR_RT_LABEL_CONTEXT = 507031; // label not in current context static const int32_t ACL_ERROR_RT_PROGRAM_USE_OUT = 507032; // program register num use out static const int32_t ACL_ERROR_RT_DEV_SETUP_ERROR = 507033; // device setup error +static const int32_t ACL_ERROR_RT_VECTOR_CORE_TIMEOUT = 507034; // vector core timeout +static const int32_t ACL_ERROR_RT_VECTOR_CORE_EXCEPTION = 507035; // vector core exception +static const int32_t ACL_ERROR_RT_VECTOR_CORE_TRAP_EXCEPTION = 507036; // vector core trap exception +static const int32_t ACL_ERROR_RT_CDQ_BATCH_ABNORMAL = 507037; // cdq alloc batch abnormal static const int32_t ACL_ERROR_RT_DRV_INTERNAL_ERROR = 507899; // drv internal error static const int32_t ACL_ERROR_RT_AICPU_INTERNAL_ERROR = 507900; // aicpu internal error +static const int32_t ACL_ERROR_RT_SOCKET_CLOSE = 507901; // hdc disconnect #ifdef __cplusplus } diff --git a/third_party/fwkacllib/inc/runtime/base.h b/third_party/fwkacllib/inc/runtime/base.h index 40bc91f7..7fc1cdea 100644 --- a/third_party/fwkacllib/inc/runtime/base.h +++ b/third_party/fwkacllib/inc/runtime/base.h @@ -156,7 +156,7 @@ RTS_API rtError_t rtProfilerTrace(uint64_t id, bool notify, uint32_t flags, rtSt /** * @ingroup profiling_base - * @brief ts send keypoint for step info. + * @brief ts send keypoint profiler log. */ RTS_API rtError_t rtProfilerTraceEx(uint64_t id, uint64_t modelId, uint16_t tagId, rtStream_t stream); @@ -206,7 +206,7 @@ RTS_API rtError_t rtRegDeviceStateCallback(const char *regName, rtDeviceStateCal /** * @ingroup dvrt_base - * @brief register callback for fail task + * @brief register callback for fail task * @param [in] uniName unique register name, can't be null * @param [in] callback fail task callback function * @param [out] NA @@ -345,11 +345,11 @@ RTS_API rtError_t rtLabelCreateEx(rtLabel_t *label, rtStream_t stream); * @return RT_ERROR_NONE for ok * @return RT_ERROR_INVALID_VALUE for error input */ -rtError_t rtLabelCreateExV2(rtLabel_t *label, rtModel_t model, rtStream_t stream); +RTS_API rtError_t rtLabelCreateExV2(rtLabel_t *label, rtModel_t model, rtStream_t stream); /** * @ingroup dvrt_base - * @brief get current thread last stream id and task id + * @brief get current thread last stream id and task id * @param [out] stream id and task id * @param [in] null * @return RT_ERROR_NONE for ok diff --git a/third_party/fwkacllib/inc/runtime/config.h b/third_party/fwkacllib/inc/runtime/config.h index ee104693..c1327c45 100644 --- a/third_party/fwkacllib/inc/runtime/config.h +++ b/third_party/fwkacllib/inc/runtime/config.h @@ -46,6 +46,12 @@ typedef enum tagRtChipType { CHIP_END, } rtChipType_t; +typedef enum tagRtAicpuScheType { + SCHEDULE_SOFTWARE = 0, /* Software Schedule */ + SCHEDULE_SOFTWARE_OPT, + SCHEDULE_HARDWARE, /* HWTS Schedule */ +} rtAicpuScheType; + typedef enum tagRtVersion { VER_BEGIN = 0, VER_NA = VER_BEGIN, @@ -65,6 +71,7 @@ typedef enum tagRtPlatformType { PLATFORM_LHISI_CS, PLATFORM_DC, PLATFORM_CLOUD_V2, + PLATFORM_LHISI_SD3403, PLATFORM_END, } rtPlatformType_t; @@ -126,6 +133,11 @@ typedef struct tagRtPlatformConfig { uint32_t platformConfig; } rtPlatformConfig_t; +typedef enum tagRTTaskTimeoutType { + RT_TIMEOUT_TYPE_OP_WAIT = 0, + RT_TIMEOUT_TYPE_OP_EXECUTE, +} rtTaskTimeoutType_t; + /** * @ingroup * @brief get AI core count @@ -184,6 +196,37 @@ RTS_API rtError_t rtMemGetL2Info(rtStream_t stream, void **ptr, uint32_t *size); */ RTS_API rtError_t rtGetRuntimeVersion(uint32_t *runtimeVersion); + +/** + * @ingroup + * @brief get device feature ability by device id, such as task schedule ability. + * @param [in] deviceId + * @param [in] moduleType + * @param [in] featureType + * @param [out] value + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetDeviceCapability(int32_t deviceId, int32_t moduleType, int32_t featureType, int32_t *value); + +/** + * @ingroup + * @brief set event wait task timeout time. + * @param [in] timeout + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtSetOpWaitTimeOut(uint32_t timeout); + +/** + * @ingroup + * @brief set op execute task timeout time. + * @param [in] timeout + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtSetOpExecuteTimeOut(uint32_t timeout); + #if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) } #endif diff --git a/third_party/fwkacllib/inc/runtime/dev.h b/third_party/fwkacllib/inc/runtime/dev.h index e82ec5fa..2cf6712f 100644 --- a/third_party/fwkacllib/inc/runtime/dev.h +++ b/third_party/fwkacllib/inc/runtime/dev.h @@ -63,6 +63,11 @@ typedef enum tagRtFeatureType { FEATURE_TYPE_RSV } rtFeatureType_t; +typedef enum tagRtDeviceFeatureType { + FEATURE_TYPE_SCHE, + FEATURE_TYPE_END, +} rtDeviceFeatureType_t; + typedef enum tagMemcpyInfo { MEMCPY_INFO_SUPPORT_ZEROCOPY = 0, MEMCPY_INFO_RSV diff --git a/third_party/fwkacllib/inc/runtime/event.h b/third_party/fwkacllib/inc/runtime/event.h index 41e611ea..57948c47 100644 --- a/third_party/fwkacllib/inc/runtime/event.h +++ b/third_party/fwkacllib/inc/runtime/event.h @@ -23,12 +23,23 @@ extern "C" { #endif +typedef enum rtEventWaitStatus { + EVENT_STATUS_COMPLETE = 0, + EVENT_STATUS_NOT_READY = 1, + EVENT_STATUS_MAX = 2, +} rtEventWaitStatus_t; + /** * @ingroup event_flags * @brief event op bit flags */ -#define RT_EVENT_DEFAULT (0x00) -#define RT_EVENT_WITH_FLAG (0x01) +#define RT_EVENT_DEFAULT (0x0E) +#define RT_EVENT_WITH_FLAG (0x0B) + +#define RT_EVENT_DDSYNC_NS 0x01U +#define RT_EVENT_STREAM_MARK 0x02U +#define RT_EVENT_DDSYNC 0x04U +#define RT_EVENT_TIME_LINE 0x08U /** * @ingroup dvrt_event @@ -104,6 +115,16 @@ RTS_API rtError_t rtEventSynchronize(rtEvent_t event); */ RTS_API rtError_t rtEventQuery(rtEvent_t event); +/** + * @ingroup dvrt_event + * @brief Queries an event's wait status + * @param [in] event event to query + * @param [in out] EVENT_WAIT_STATUS status + * @return EVENT_STATUS_COMPLETE for complete + * @return EVENT_STATUS_NOT_READY for not complete + */ +RTS_API rtError_t rtEventQueryWaitStatus(rtEvent_t event, rtEventWaitStatus_t *status); + /** * @ingroup dvrt_event * @brief computes the elapsed time between events. @@ -176,6 +197,18 @@ RTS_API rtError_t rtNotifyRecord(rtNotify_t notify, rtStream_t stream); */ RTS_API rtError_t rtNotifyWait(rtNotify_t notify, rtStream_t stream); +/** + * @ingroup dvrt_event + * @brief Wait for a notify with time out + * @param [in] notify_ notify to be wait + * @param [in] stream_ input stream + * @param [in] timeOut input timeOut + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_STREAM_CONTEXT for stream is not in current ctx + */ +RTS_API rtError_t rtNotifyWaitWithTimeOut(rtNotify_t notify_, rtStream_t stream_, uint32_t timeOut); + /** * @ingroup dvrt_event * @brief Name a notify diff --git a/third_party/fwkacllib/inc/runtime/kernel.h b/third_party/fwkacllib/inc/runtime/kernel.h index 402fadef..9b0221c7 100644 --- a/third_party/fwkacllib/inc/runtime/kernel.h +++ b/third_party/fwkacllib/inc/runtime/kernel.h @@ -111,6 +111,16 @@ typedef struct rtKernelInfo { uint32_t module_size; } *rtKernelInfo_t; +/** + * @ingroup rt_kernel + * @brief op name + */ +typedef struct rtKernelLaunchNames { + const char *soName; // defined for so name + const char *kernelName; // defined for kernel type name + const char *opName; // defined for operator name +} rtKernelLaunchNames_t; + /** * @ingroup rt_KernelConfigDump * @brief device dump type @@ -173,13 +183,7 @@ typedef void (*rtCallback_t)(void *fnData); * @ingroup rt_kernel * @brief magic number of elf binary for aicube */ -#define RT_DEV_BINARY_MAGIC_ELF_AICUBE 0x41415247 - -/** - * @ingroup rt_kernel - * @brief magic number of elf binary for aivector - */ -#define RT_DEV_BINARY_MAGIC_ELF_AIVECTOR 0x41415248 +#define RT_DEV_BINARY_MAGIC_ELF_AICUBE 0x41494343 /** * @ingroup rt_kernel_flags @@ -192,14 +196,14 @@ typedef void (*rtCallback_t)(void *fnData); #define RT_KERNEL_CUSTOM_AICPU (0x08) // STARS topic scheduler sqe : topic_type -#define RT_KERNEL_DEVICE_FIRST (0X10) -#define RT_KERNEL_HOST_ONLY (0X20) -#define RT_KERNEL_HOST_FIRST (0X30) +#define RT_KERNEL_DEVICE_FIRST (0x10) +#define RT_KERNEL_HOST_ONLY (0x20) +#define RT_KERNEL_HOST_FIRST (0x40) /** * @ingroup rt_kernel * @brief kernel mode - */ +**/ #define RT_DEFAULT_KERNEL_MODE (0x00) #define RT_NORMAL_KERNEL_MODE (0x01) #define RT_ALL_KERNEL_MODE (0x02) @@ -222,7 +226,7 @@ RTS_API rtError_t rtDevBinaryRegister(const rtDevBinary_t *bin, void **handle); /** * @ingroup rt_kernel - * @brief register device binary + * @brief register device binary with all kernel * @param [in] bin device binary description * @param [out] handle device binary handle * @return RT_ERROR_NONE for ok @@ -341,7 +345,7 @@ RTS_API rtError_t rtKernelLaunch(const void *stubFunc, uint32_t blockDim, void * * @ingroup rt_kernel * @brief launch kernel with handle to device * @param [in] handle program - * @param [in] devFunc device function description + * @param [in] devFunc device function description. * @param [in] blockDim block dimentions * @param [in] args argments address for kernel function * @param [in] argsSize argements size @@ -352,7 +356,7 @@ RTS_API rtError_t rtKernelLaunch(const void *stubFunc, uint32_t blockDim, void * * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtKernelLaunchWithHandle(void *handle, const void *devFunc, uint32_t blockDim, void *args, uint32_t argsSize, - rtSmDesc_t *smDesc, rtStream_t stream, const void *kernelInfo); + rtSmDesc_t *smDesc, rtStream_t stream_, const void *kernelInfo); /** * @ingroup rt_kernel @@ -371,7 +375,7 @@ RTS_API rtError_t rtKernelLaunchWithFlag(const void *stubFunc, uint32_t blockDim rtSmDesc_t *smDesc, rtStream_t stream, uint32_t flags); /** - * @ingroup rt_kernel + * @ingroup rt_kernel(abandoned) * @brief launch kernel to device * @param [in] args argments address for kernel function * @param [in] argsSize argements size @@ -383,7 +387,21 @@ RTS_API rtError_t rtKernelLaunchWithFlag(const void *stubFunc, uint32_t blockDim RTS_API rtError_t rtKernelLaunchEx(void *args, uint32_t argsSize, uint32_t flags, rtStream_t stream); /** - * @ingroup rt_kernel + * @ingroup rt_kernel(in use) + * @brief launch kernel to device + * @param [in] opName opkernel name + * @param [in] args argments address for kernel function + * @param [in] argsSize argements size + * @param [in] flags launch flags + * @param [in] stream associated stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtKernelLaunchFwk(const char *opName, void *args, uint32_t argsSize, uint32_t flags, + rtStream_t rtStream); + +/** + * @ingroup rt_kernel(abandoned) * @brief launch cpu kernel to device * @param [in] soName so name * @param [in] kernelName kernel name @@ -399,7 +417,22 @@ RTS_API rtError_t rtCpuKernelLaunch(const void *soName, const void *kernelName, uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream); /** - * @ingroup rt_kernel + * @ingroup rt_kernel(in use) + * @brief launch cpu kernel to device + * @param [in] launchNames names for kernel launch + * @param [in] blockDim block dimentions + * @param [in] args argments address for kernel function + * @param [in] argsSize argments size + * @param [in] smDesc shared memory description + * @param [in] stream associated stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtAicpuKernelLaunch(const rtKernelLaunchNames_t *launchNames, + uint32_t blockDim, const void *args, uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream); + +/** + * @ingroup rt_kernel(abandoned) * @brief launch cpu kernel to device with dump identifier * @param [in] soName so name * @param [in] kernelName kernel name @@ -416,6 +449,22 @@ RTS_API rtError_t rtCpuKernelLaunchWithFlag(const void *soName, const void *kern const void *args, uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream, uint32_t flags); +/** + * @ingroup rt_kernel(in use) + * @brief launch cpu kernel to device with dump identifier + * @param [in] launchNames names for kernel launch + * @param [in] blockDim block dimentions + * @param [in] args argments address for kernel function + * @param [in] argsSize argments size + * @param [in] smDesc shared memory description + * @param [in] stream associated stream + * @param [in] flag dump flag or others function flag + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtAicpuKernelLaunchWithFlag(const rtKernelLaunchNames_t *launchNames, uint32_t blockDim, + const void *args, uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream, uint32_t flags); + /** * @ingroup rt_kernel * @brief L1 fusion dump addr transfered to device diff --git a/third_party/fwkacllib/inc/runtime/mem.h b/third_party/fwkacllib/inc/runtime/mem.h index 30af85d9..bace4bc6 100644 --- a/third_party/fwkacllib/inc/runtime/mem.h +++ b/third_party/fwkacllib/inc/runtime/mem.h @@ -116,6 +116,9 @@ typedef enum tagRtMemInfoType { typedef enum tagRtRecudeKind { RT_MEMCPY_SDMA_AUTOMATIC_ADD = 10, // D2D, SDMA inline reduce, include 1P, and P2P + RT_MEMCPY_SDMA_AUTOMATIC_MAX = 11, + RT_MEMCPY_SDMA_AUTOMATIC_MIN = 12, + RT_MEMCPY_SDMA_AUTOMATIC_EQUAL = 13, RT_RECUDE_KIND_END } rtRecudeKind_t; @@ -123,6 +126,14 @@ typedef enum tagRtDataType { RT_DATA_TYPE_FP32 = 0, // fp32 RT_DATA_TYPE_FP16 = 1, // fp16 RT_DATA_TYPE_INT16 = 2, // int16 + RT_DATA_TYPE_INT4 = 3, // int4 + RT_DATA_TYPE_INT8 = 4, // int8 + RT_DATA_TYPE_INT32 = 5, // int32 + RT_DATA_TYPE_BFP16 = 6, // bfp16 + RT_DATA_TYPE_BFP32 = 7, // bfp32 + RT_DATA_TYPE_UINT8 = 8, // uint8 + RT_DATA_TYPE_UINT16= 9, // uint16 + RT_DATA_TYPE_UINT32= 10,// uint32 RT_DATA_TYPE_END } rtDataType_t; diff --git a/third_party/fwkacllib/inc/runtime/rt_model.h b/third_party/fwkacllib/inc/runtime/rt_model.h index 30b8f053..a7618b45 100644 --- a/third_party/fwkacllib/inc/runtime/rt_model.h +++ b/third_party/fwkacllib/inc/runtime/rt_model.h @@ -135,12 +135,13 @@ typedef struct tagAllKernelTaskInfo { uint16_t argsCount; uint16_t argsSize; uint16_t reserved; - const void *dev_func; + void *devfunc; void *handle; uint8_t *smDesc; uint8_t *args; uint16_t *argsOffset; } rtAllKernelTaskInfo_t; + typedef struct tagKernelTaskInfoEx { uint32_t flags; uint32_t argsSize; @@ -198,6 +199,13 @@ typedef struct tagProfilerTraceTaskInfo { uint32_t reserved[6]; } rtProfilerTrace_t; +typedef struct tagProfilerTraceExTaskInfo { + uint64_t profilerTraceId; + uint64_t modelId; + uint16_t tagId; + uint8_t reserved[22]; +} rtProfilerTraceEx_t; + typedef struct tagrtMemcpyAsyncTaskInfo { void *dst; uint64_t destMax; @@ -265,7 +273,7 @@ typedef struct tagTaskInfo { union { rtKernelTaskInfoEx_t kernelTaskEx; rtKernelTaskInfo_t kernelTask; - rtAllKernelTaskInfo_t allkernelTask; + rtAllKernelTaskInfo_t allKernelTask; rtEventTaskInfo_t eventTask; rtStreamSwitchTaskInfo_t streamSwitchTask; rtStreamActiveTaskInfo_t streamActiveTask; @@ -273,6 +281,7 @@ typedef struct tagTaskInfo { rtLabelSwitchTaskInfo_t labelSwitchTask; rtLabelGotoTaskInfo_t labelGotoTask; rtProfilerTrace_t profilertraceTask; + rtProfilerTraceEx_t profilertraceExTask; rtMemcpyAsyncTaskInfo_t memcpyAsyncTask; rtNotifyTaskInfo_t notifyTask; rtReduceAsyncTaskInfo_t reduceAsyncTask; diff --git a/third_party/fwkacllib/inc/toolchain/prof_callback.h b/third_party/fwkacllib/inc/toolchain/prof_callback.h index 18550157..5073cfb1 100644 --- a/third_party/fwkacllib/inc/toolchain/prof_callback.h +++ b/third_party/fwkacllib/inc/toolchain/prof_callback.h @@ -108,7 +108,19 @@ enum MsprofCtrlCallbackType { MSPROF_CTRL_INIT_ACL_ENV = 0, // start profiling with acl env MSPROF_CTRL_INIT_ACL_JSON, // start profiling with acl.json MSPROF_CTRL_INIT_GE_OPTIONS, // start profiling with ge env and options - MSPROF_CTRL_FINALIZE // stop profiling + MSPROF_CTRL_FINALIZE, // stop profiling + MSPROF_CTRL_REPORT_FUN_P, // for report callback + MSPROF_CTRL_PROF_SWITCH_ON, // for prof switch on + MSPROF_CTRL_PROF_SWITCH_OFF // for prof switch off +}; + +#define MSPROF_MAX_DEV_NUM (64) + +struct MsprofCommandHandle { + uint64_t profSwitch; + uint32_t devNums; // length of device id list + uint32_t devIdList[MSPROF_MAX_DEV_NUM]; + uint32_t modelId; }; /** @@ -129,6 +141,23 @@ typedef int32_t (*MsprofCtrlCallback)(uint32_t type, void *data, uint32_t len); */ typedef void (*MsprofSetDeviceCallback)(uint32_t devId, bool isOpenDevice); +/* + * @name MsprofInit + * @brief Profiling module init + * @param [in] dataType: profiling type: ACL Env/ACL Json/GE Option + * @param [in] data: profiling switch data + * @param [in] dataLen: Length of data + * @return 0:SUCCESS, >0:FAILED + */ +int32_t MsprofInit(uint32_t dataType, void *data, uint32_t dataLen); + +/* + * @name AscendCL + * @brief Finishing Profiling + * @param NULL + * @return 0:SUCCESS, >0:FAILED + */ +int32_t MsprofFinalize(); #ifdef __cplusplus } #endif diff --git a/third_party/fwkacllib/inc/toolchain/slog.h b/third_party/fwkacllib/inc/toolchain/slog.h index ba286d02..cc7c83ca 100644 --- a/third_party/fwkacllib/inc/toolchain/slog.h +++ b/third_party/fwkacllib/inc/toolchain/slog.h @@ -17,6 +17,8 @@ #ifndef D_SYSLOG_H_ #define D_SYSLOG_H_ +static const int TMP_LOG = 0; + #ifdef __cplusplus #ifndef LOG_CPP extern "C" { @@ -120,15 +122,15 @@ typedef struct tagKV { } KeyValue; typedef enum { - APPLICATION = 0, - SYSTEM + APPLICATION = 0, + SYSTEM } ProcessType; typedef struct { - ProcessType type; - unsigned int pid; - unsigned int deviceId; - char reserved[RESERVERD_LENGTH]; + ProcessType type; + unsigned int pid; + unsigned int deviceId; + char reserved[RESERVERD_LENGTH]; } LogAttr; /** @@ -141,7 +143,7 @@ enum { IDEDD, /**< IDE daemon device */ IDEDH, /**< IDE daemon host */ HCCL, /**< HCCL */ - FMK, /**< Framework */ + FMK, /**< Adapter */ HIAIENGINE, /**< Matrix */ DVPP, /**< DVPP */ RUNTIME, /**< Runtime */ @@ -162,11 +164,11 @@ enum { MDCDEFAULT, /**< MDC undefine */ MDCSC, /**< MDC spatial cognition */ MDCPNC, - MLL, + MLL, /**< abandon */ DEVMM, /**< Dlog memory managent */ KERNEL, /**< Kernel */ LIBMEDIA, /**< Libmedia */ - CCECPU, /**< ai cpu */ + CCECPU, /**< aicpu shedule */ ASCENDDK, /**< AscendDK */ ROS, /**< ROS */ HCCP, @@ -179,7 +181,7 @@ enum { TSDUMP, /**< TSDUMP module */ AICPU, /**< AICPU module */ LP, /**< LP module */ - TDT, + TDT, /**< tsdaemon or aicpu shedule */ FE, MD, MB, @@ -261,7 +263,7 @@ DLL_EXPORT int DlogSetAttr(LogAttr logAttr); #define dlog_error(moduleId, fmt, ...) \ do { \ DlogErrorInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ - } while (0) + } while (TMP_LOG != 0) /** * @ingroup slog @@ -276,7 +278,7 @@ DLL_EXPORT int DlogSetAttr(LogAttr logAttr); if(CheckLogLevel(moduleId, DLOG_WARN) == 1) { \ DlogWarnInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ } \ - } while (0) + } while (TMP_LOG != 0) /** * @ingroup slog @@ -291,7 +293,7 @@ DLL_EXPORT int DlogSetAttr(LogAttr logAttr); if(CheckLogLevel(moduleId, DLOG_INFO) == 1) { \ DlogInfoInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ } \ - } while (0) + } while (TMP_LOG != 0) /** * @ingroup slog @@ -306,7 +308,7 @@ DLL_EXPORT int DlogSetAttr(LogAttr logAttr); if(CheckLogLevel(moduleId, DLOG_DEBUG) == 1) { \ DlogDebugInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ } \ - } while (0) + } while (TMP_LOG != 0) /** * @ingroup slog @@ -318,7 +320,7 @@ DLL_EXPORT int DlogSetAttr(LogAttr logAttr); #define dlog_event(moduleId, fmt, ...) \ do { \ DlogEventInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ - } while (0) + } while (TMP_LOG != 0) /** * @ingroup slog @@ -334,7 +336,7 @@ DLL_EXPORT int DlogSetAttr(LogAttr logAttr); if(CheckLogLevel(moduleId, level) == 1) { \ DlogInner(moduleId, level, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ } \ - } while (0) + } while (TMP_LOG != 0) /** * @ingroup slog @@ -351,7 +353,7 @@ DLL_EXPORT int DlogSetAttr(LogAttr logAttr); if(CheckLogLevel(moduleId, level) == 1) { \ DlogInner(moduleId, level, "[%s:%d][%s]" fmt, __FILE__, __LINE__, submodule, ##__VA_ARGS__); \ } \ - } while (0) + } while (TMP_LOG != 0) /** * @ingroup slog @@ -369,7 +371,7 @@ DLL_EXPORT int DlogSetAttr(LogAttr logAttr); if(CheckLogLevel(moduleId, level) == 1) { \ DlogWithKVInner(moduleId, level, pstKVArray, kvNum, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ } \ - } while (0) + } while (TMP_LOG != 0) /** * @ingroup slog @@ -381,13 +383,13 @@ DLL_EXPORT void DlogFlush(void); * @ingroup slog * @brief Internal log interface, other modules are not allowed to call this interface */ -void DlogErrorInner(int moduleId, const char *fmt, ...) __attribute__((format(printf, 2, 3))); -void DlogWarnInner(int moduleId, const char *fmt, ...) __attribute__((format(printf, 2, 3))); -void DlogInfoInner(int moduleId, const char *fmt, ...) __attribute__((format(printf, 2, 3))); -void DlogDebugInner(int moduleId, const char *fmt, ...) __attribute__((format(printf, 2, 3))); -void DlogEventInner(int moduleId, const char *fmt, ...) __attribute__((format(printf, 2, 3))); -void DlogInner(int moduleId, int level, const char *fmt, ...) __attribute__((format(printf, 3, 4))); -void DlogWithKVInner(int moduleId, int level, KeyValue *pstKVArray, int kvNum, const char *fmt, ...) __attribute__((format(printf, 5, 6))); +void DlogErrorInner(int moduleId, const char *fmt, ...); +void DlogWarnInner(int moduleId, const char *fmt, ...); +void DlogInfoInner(int moduleId, const char *fmt, ...); +void DlogDebugInner(int moduleId, const char *fmt, ...); +void DlogEventInner(int moduleId, const char *fmt, ...); +void DlogInner(int moduleId, int level, const char *fmt, ...); +void DlogWithKVInner(int moduleId, int level, KeyValue *pstKVArray, int kvNum, const char *fmt, ...); #ifdef __cplusplus #ifndef LOG_CPP @@ -453,7 +455,7 @@ DLL_EXPORT int DlogSetAttrForC(LogAttr logAttr); if(CheckLogLevelForC(moduleId, level) == 1) { \ DlogInnerForC(moduleId, level, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ } \ - } while (0) + } while (TMP_LOG != 0) /** * @ingroup slog @@ -470,7 +472,7 @@ DLL_EXPORT int DlogSetAttrForC(LogAttr logAttr); if(CheckLogLevelForC(moduleId, level) == 1) { \ DlogInnerForC(moduleId, level, "[%s:%d][%s]" fmt, __FILE__, __LINE__, submodule, ##__VA_ARGS__); \ } \ - } while (0) + } while (TMP_LOG != 0) /** * @ingroup slog @@ -488,7 +490,7 @@ DLL_EXPORT int DlogSetAttrForC(LogAttr logAttr); if(CheckLogLevelForC(moduleId, level) == 1) { \ DlogWithKVInnerForC(moduleId, level, pstKVArray, kvNum, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ } \ - } while (0) + } while (TMP_LOG != 0) /** * @ingroup slog @@ -500,8 +502,8 @@ DLL_EXPORT void DlogFlushForC(void); * @ingroup slog * @brief Internal log interface, other modules are not allowed to call this interface */ -void DlogInnerForC(int moduleId, int level, const char *fmt, ...) __attribute__((format(printf, 3, 4))); -void DlogWithKVInnerForC(int moduleId, int level, KeyValue *pstKVArray, int kvNum, const char *fmt, ...) __attribute__((format(printf, 5, 6))); +void DlogInnerForC(int moduleId, int level, const char *fmt, ...); +void DlogWithKVInnerForC(int moduleId, int level, KeyValue *pstKVArray, int kvNum, const char *fmt, ...); #ifdef __cplusplus } diff --git a/third_party/fwkacllib/inc/toolchain/tuning_tool/tune_api.h b/third_party/fwkacllib/inc/toolchain/tuning_tool/tune_api.h index 6208f462..2cf6e0c4 100644 --- a/third_party/fwkacllib/inc/toolchain/tuning_tool/tune_api.h +++ b/third_party/fwkacllib/inc/toolchain/tuning_tool/tune_api.h @@ -1,72 +1,88 @@ -/** - * @file tune_api.h - * - * Copyright (c) Huawei Technologies Co., Ltd. 2020-2020. All rights reserved.\n - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.\n - * 描述:mstune调优接口头文件 - */ -/** @defgroup mstune mstune调优接口 */ -#ifndef TUNE_API_H -#define TUNE_API_H -#include -#include -#include -#include "graph/graph.h" -#include "ge/ge_api.h" - -/** - * @ingroup mstune - * - * mstune status - */ -enum MsTuneStatus { - MSTUNE_SUCCESS, /** tune success */ - MSTUNE_FAILED, /** tune failed */ -}; - -// Option key: for train options sets -const std::string MSTUNE_SELF_KEY = "mstune"; -const std::string MSTUNE_GEINIT_KEY = "initialize"; -const std::string MSTUNE_GESESS_KEY = "session"; - -/** - * @ingroup mstune - * @par 描述: 命令行调优 - * - * @attention 无 - * @param option [IN] 调优参数 - * @param msg [OUT] 调优异常下返回信息 - * @retval #MSTUNE_SUCCESS 执行成功 - * @retval #MSTUNE_FAILED 执行失败 - * @par 依赖: - * @li tune_api.cpp:该接口所属的开发包。 - * @li tune_api.h:该接口声明所在的头文件。 - * @see 无 - * @since - */ -MsTuneStatus MsTuning(const std::map &option, std::string &msg); - -/** - * @ingroup mstune - * @par 描述: 梯度调优 - * - * @attention 无 - * @param tuningGraph [IN] 调优图 - * @param dependGraph [IN] 调优依赖图 - * @param session [IN] ge连接会话 - * @param option [IN] 参数集. 包含调优参数及ge参数 - * @retval #MSTUNE_SUCCESS 执行成功 - * @retval #MSTUNE_FAILED 执行失败 - * @par 依赖: - * @li tune_api.cpp:该接口所属的开发包。 - * @li tune_api.h:该接口声明所在的头文件。 - * @see 无 - * @since - */ -extern "C" MsTuneStatus MsTrainTuning(ge::Graph &tuningGraph, std::vector &dependGraph, - ge::Session *session, const std::map> &option); - -#endif +/** + * @file tune_api.h + * + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.\n + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.\n + * 描述:aoe调优接口头文件 + */ +/** @defgroup aoe aoe调优接口 */ +#ifndef TUNE_API_H +#define TUNE_API_H +#include +#include +#include "ge/ge_api.h" +#include "aoe_types.h" + +/** + * @ingroup aoe + * @par 描述: 命令行调优 + * + * @attention 无 + * @param option [IN] 调优参数 + * @param msg [OUT] 调优异常下返回信息 + * @retval #AOE_SUCCESS 执行成功 + * @retval #AOE_FAILURE 执行失败 + * @par 依赖: + * @li tune_api.cpp:该接口所属的开发包。 + * @li tune_api.h:该接口声明所在的头文件。 + * @see 无 + * @since + */ +AoeStatus AoeOfflineTuning(const std::map &option, std::string &msg); + +/** + * @ingroup aoe + * @par 描述: 调优初始化 + * + * @attention 无 + * @param session [IN] ge连接会话 + * @param option [IN] 参数集. 包含调优参数及ge参数 + * @retval #AOE_SUCCESS 执行成功 + * @retval #AOE_FAILURE 执行失败 + * @par 依赖: + * @li tune_api.cpp:该接口所属的开发包。 + * @li tune_api.h:该接口声明所在的头文件。 + * @see 无 + * @since + */ +extern "C" AoeStatus AoeOnlineInitialize(ge::Session *session, const std::map &option); + +/** + * @ingroup aoe + * @par 描述: 调优去初始化 + * + * @attention 无 + * @param 无 + * @retval #AOE_SUCCESS 执行成功 + * @retval #AOE_FAILURE 执行失败 + * @par 依赖: + * @li tune_api.cpp:该接口所属的开发包。 + * @li tune_api.h:该接口声明所在的头文件。 + * @see 无 + * @since + */ +extern "C" AoeStatus AoeOnlineFinalize(); + +/** + * @ingroup aoe + * @par 描述: 调优处理 + * + * @attention 无 + * @param tuningGraph [IN] 调优图 + * @param dependGraph [IN] 调优依赖图 + * @param session [IN] ge连接会话 + * @param option [IN] 参数集. 包含调优参数及ge参数 + * @retval #AOE_SUCCESS 执行成功 + * @retval #AOE_FAILURE 执行失败 + * @par 依赖: + * @li tune_api.cpp:该接口所属的开发包。 + * @li tune_api.h:该接口声明所在的头文件。 + * @see 无 + * @since + */ +extern "C" AoeStatus AoeOnlineTuning(ge::Graph &tuningGraph, std::vector &dependGraph, + ge::Session *session, const std::map &option); +#endif