| @@ -52,7 +52,6 @@ ConstructorInitializerAllOnOneLineOrOnePerLine: true | |||||
| ConstructorInitializerIndentWidth: 4 | ConstructorInitializerIndentWidth: 4 | ||||
| ContinuationIndentWidth: 4 | ContinuationIndentWidth: 4 | ||||
| Cpp11BracedListStyle: true | Cpp11BracedListStyle: true | ||||
| DerivePointerAlignment: true | |||||
| DisableFormat: false | DisableFormat: false | ||||
| ExperimentalAutoDetectBinPacking: false | ExperimentalAutoDetectBinPacking: false | ||||
| FixNamespaceComments: true | FixNamespaceComments: true | ||||
| @@ -94,7 +93,7 @@ PenaltyBreakString: 1000 | |||||
| PenaltyBreakTemplateDeclaration: 10 | PenaltyBreakTemplateDeclaration: 10 | ||||
| PenaltyExcessCharacter: 1000000 | PenaltyExcessCharacter: 1000000 | ||||
| PenaltyReturnTypeOnItsOwnLine: 200 | PenaltyReturnTypeOnItsOwnLine: 200 | ||||
| PointerAlignment: Left | |||||
| PointerAlignment: Right | |||||
| RawStringFormats: | RawStringFormats: | ||||
| - Language: Cpp | - Language: Cpp | ||||
| Delimiters: | Delimiters: | ||||
| @@ -95,6 +95,7 @@ target_link_libraries(ge_common PRIVATE | |||||
| $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:slog_headers>> | $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:slog_headers>> | ||||
| $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:msprof_headers>> | $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:msprof_headers>> | ||||
| $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:mmpa_headers>> | $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:mmpa_headers>> | ||||
| $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:runtime_headers>> | |||||
| static_mmpa | static_mmpa | ||||
| -Wl,--no-as-needed | -Wl,--no-as-needed | ||||
| graph | graph | ||||
| @@ -155,6 +156,7 @@ target_link_libraries(ge_common_static PRIVATE | |||||
| $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:slog_headers>> | $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:slog_headers>> | ||||
| $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:msprof_headers>> | $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:msprof_headers>> | ||||
| $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:mmpa_headers>> | $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:mmpa_headers>> | ||||
| $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:runtime_headers>> | |||||
| ascend_protobuf_static | ascend_protobuf_static | ||||
| json | json | ||||
| c_sec | c_sec | ||||
| @@ -163,7 +163,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status DumpProperties::CheckDum | |||||
| GELOGE(PARAM_INVALID, "[Check][Param] the dumpPath:%s, is invalid.", input.c_str()); | GELOGE(PARAM_INVALID, "[Check][Param] the dumpPath:%s, is invalid.", input.c_str()); | ||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| if (mmAccess2(trusted_path, R_OK | W_OK) != EN_OK) { | |||||
| if (mmAccess2(trusted_path, M_R_OK | M_W_OK) != EN_OK) { | |||||
| REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}), | REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}), | ||||
| std::vector<std::string>({ | std::vector<std::string>({ | ||||
| "ge.exec.dumpPath", | "ge.exec.dumpPath", | ||||
| @@ -161,6 +161,7 @@ Status ExceptionDumper::DumpExceptionInfo(const std::vector<rtExceptionInfo> &ex | |||||
| uint64_t proto_size = dump_data.ByteSizeLong(); | uint64_t proto_size = dump_data.ByteSizeLong(); | ||||
| std::unique_ptr<char[]> proto_msg(new (std::nothrow) char[proto_size]); | std::unique_ptr<char[]> proto_msg(new (std::nothrow) char[proto_size]); | ||||
| GE_CHECK_NOTNULL(proto_msg); | |||||
| bool ret = dump_data.SerializeToArray(proto_msg.get(), proto_size); | bool ret = dump_data.SerializeToArray(proto_msg.get(), proto_size); | ||||
| if (!ret || proto_size == 0) { | if (!ret || proto_size == 0) { | ||||
| REPORT_INNER_ERROR("E19999", "Serialize proto to string fail"); | REPORT_INNER_ERROR("E19999", "Serialize proto to string fail"); | ||||
| @@ -186,6 +186,8 @@ target_include_directories(ge_executor SYSTEM PRIVATE | |||||
| ${CMAKE_BINARY_DIR}/proto/graphengine_protos | ${CMAKE_BINARY_DIR}/proto/graphengine_protos | ||||
| #### yellow zone #### | #### yellow zone #### | ||||
| $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:${GE_DEPEND_DIR}/inc> | $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:${GE_DEPEND_DIR}/inc> | ||||
| $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<TARGET_PROPERTY:runtime_headers,INTERFACE_INCLUDE_DIRECTORIES>> | |||||
| $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<TARGET_PROPERTY:cce_headers,INTERFACE_INCLUDE_DIRECTORIES>> | |||||
| #### blue zone #### | #### blue zone #### | ||||
| $<$<BOOL:${ENABLE_OPEN_SRC}>:${GE_CODE_DIR}/third_party/fwkacllib/inc> | $<$<BOOL:${ENABLE_OPEN_SRC}>:${GE_CODE_DIR}/third_party/fwkacllib/inc> | ||||
| $<$<BOOL:${ENABLE_OPEN_SRC}>:${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain> | $<$<BOOL:${ENABLE_OPEN_SRC}>:${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain> | ||||
| @@ -251,6 +253,8 @@ target_link_libraries(ge_executor_shared PRIVATE | |||||
| $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:slog_headers>> | $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:slog_headers>> | ||||
| $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:msprof_headers>> | $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:msprof_headers>> | ||||
| $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:mmpa_headers>> | $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:mmpa_headers>> | ||||
| $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:runtime_headers>> | |||||
| $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:cce_headers>> | |||||
| -Wl,--no-as-needed | -Wl,--no-as-needed | ||||
| ge_common | ge_common | ||||
| runtime | runtime | ||||
| @@ -72,7 +72,7 @@ bool LabelGotoTask::Distribute() { | |||||
| return false; | return false; | ||||
| } | } | ||||
| rt_ret = rtLabelListCpy((void**)label_list.data(), label_list.size(), label_info_, label_info_size); | |||||
| rt_ret = rtLabelListCpy(reinterpret_cast<void**>(label_list.data()), label_list.size(), label_info_, label_info_size); | |||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); | GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); | ||||
| return false; | return false; | ||||
| @@ -32,7 +32,6 @@ | |||||
| #include "graph/ge_attr_value.h" | #include "graph/ge_attr_value.h" | ||||
| #include "graph/ge_context.h" | #include "graph/ge_context.h" | ||||
| #include "external/graph/ge_error_codes.h" | #include "external/graph/ge_error_codes.h" | ||||
| #include "graph/manager/graph_mem_allocator.h" | |||||
| #include "graph/manager/graph_var_manager.h" | #include "graph/manager/graph_var_manager.h" | ||||
| #include "graph/optimize/common/params.h" | #include "graph/optimize/common/params.h" | ||||
| #include "external/graph/types.h" | #include "external/graph/types.h" | ||||
| @@ -707,7 +706,7 @@ Status ModelBuilder::SaveDataToModel(ge::Model &model, ge::GeModel &ge_model) { | |||||
| if (!kernel_name.empty() && (kernel_buffer.GetSize() > 0)) { | if (!kernel_name.empty() && (kernel_buffer.GetSize() > 0)) { | ||||
| GE_CHECK_NOTNULL(kernel_buffer.GetData()); | GE_CHECK_NOTNULL(kernel_buffer.GetData()); | ||||
| std::vector<char> data(kernel_buffer.GetData(), kernel_buffer.GetData() + kernel_buffer.GetSize()); | std::vector<char> data(kernel_buffer.GetData(), kernel_buffer.GetData() + kernel_buffer.GetSize()); | ||||
| tbe_kernel = std::make_shared<OpKernelBin>(kernel_name, std::move(data)); | |||||
| tbe_kernel = MakeShared<OpKernelBin>(kernel_name, std::move(data)); | |||||
| GE_CHECK_NOTNULL(tbe_kernel); | GE_CHECK_NOTNULL(tbe_kernel); | ||||
| GELOGI("Node [%s][%s] start recovery extra attr %s from %s", node_op_desc->GetName().c_str(), | GELOGI("Node [%s][%s] start recovery extra attr %s from %s", node_op_desc->GetName().c_str(), | ||||
| node_op_desc->GetType().c_str(), ge::OP_EXTATTR_NAME_TBE_KERNEL, ATTR_NAME_TBE_KERNEL_NAME.c_str()); | node_op_desc->GetType().c_str(), ge::OP_EXTATTR_NAME_TBE_KERNEL, ATTR_NAME_TBE_KERNEL_NAME.c_str()); | ||||
| @@ -793,7 +793,6 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP | |||||
| GELOGI("Start AutoFindBpOpIndex"); | GELOGI("Start AutoFindBpOpIndex"); | ||||
| NodePtr bp_node = nullptr; | NodePtr bp_node = nullptr; | ||||
| uint32_t current_idx = 0; | uint32_t current_idx = 0; | ||||
| uint32_t netoutput_idx = 0; | |||||
| for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { | for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { | ||||
| OpDescPtr op_desc = node->GetOpDesc(); | OpDescPtr op_desc = node->GetOpDesc(); | ||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| @@ -811,7 +810,6 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP | |||||
| if (op_desc->GetName() == NODE_NAME_NET_OUTPUT) { | if (op_desc->GetName() == NODE_NAME_NET_OUTPUT) { | ||||
| if (bp_node == nullptr) { | if (bp_node == nullptr) { | ||||
| bp_node = node; | bp_node = node; | ||||
| netoutput_idx = current_idx - 1; | |||||
| } | } | ||||
| } | } | ||||
| if (graph->GetNeedIteration()) { | if (graph->GetNeedIteration()) { | ||||
| @@ -836,34 +834,30 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP | |||||
| if (bp_node == nullptr) { | if (bp_node == nullptr) { | ||||
| GELOGW("not find bp_node."); | GELOGW("not find bp_node."); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } else if (bp_node->GetName() == NODE_NAME_NET_OUTPUT) { | |||||
| profiling_point.bp_index = netoutput_idx; | |||||
| GELOGI("First bp name %s, idx %u", bp_node->GetName().c_str(), netoutput_idx); | |||||
| } else { | |||||
| profiling_point.bp_index = FindLastBpFromBpNode(graph, bp_node); | |||||
| } | } | ||||
| return SUCCESS; | |||||
| return FindLastBpFromBpNode(graph, bp_node, profiling_point.bp_index); | |||||
| } | } | ||||
| uint32_t TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node) const { | |||||
| uint32_t last_bp = 0; | |||||
| Status TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &target_node, | |||||
| uint32_t &bp_index) const { | |||||
| bp_index = 0; | |||||
| auto target_desc = target_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(target_desc); | |||||
| OpDescPtr bp_op_desc = nullptr; | OpDescPtr bp_op_desc = nullptr; | ||||
| for (auto &in_anchor : bp_node->GetAllInDataAnchors()) { | |||||
| auto out_anchor = in_anchor->GetPeerOutAnchor(); | |||||
| if (out_anchor == nullptr || out_anchor->GetOwnerNode() == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto out_node_desc = out_anchor->GetOwnerNode()->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(out_node_desc); | |||||
| if (bp_op_desc == nullptr || ((out_node_desc->GetId()) > (bp_op_desc->GetId()))) { | |||||
| bp_op_desc = out_node_desc; | |||||
| for (auto &in_node : target_node->GetInAllNodes()) { | |||||
| GE_CHECK_NOTNULL(in_node); | |||||
| auto in_node_desc = in_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(in_node_desc); | |||||
| if ((bp_op_desc == nullptr || (in_node_desc->GetId() > bp_op_desc->GetId())) && | |||||
| (in_node_desc->GetStreamId() == target_desc->GetStreamId())){ | |||||
| bp_op_desc = in_node_desc; | |||||
| } | } | ||||
| GELOGI("bp_op_desc is %s, id is %ld", bp_op_desc->GetName().c_str(), bp_op_desc->GetId()); | |||||
| } | } | ||||
| if (bp_op_desc == nullptr) { | if (bp_op_desc == nullptr) { | ||||
| return last_bp; | |||||
| GELOGI("Did not find bp node."); | |||||
| return SUCCESS; | |||||
| } | } | ||||
| uint32_t current_idx = 0; | uint32_t current_idx = 0; | ||||
| for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { | for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { | ||||
| @@ -871,12 +865,14 @@ uint32_t TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const | |||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| current_idx++; | current_idx++; | ||||
| if (op_desc->GetName() == bp_op_desc->GetName()) { | if (op_desc->GetName() == bp_op_desc->GetName()) { | ||||
| last_bp = current_idx; | |||||
| GELOGI("First bp name %s, idx %u", op_desc->GetName().c_str(), last_bp); | |||||
| bp_index = current_idx; | |||||
| GELOGI("Find bp name %s, idx %u", op_desc->GetName().c_str(), bp_index); | |||||
| break; | break; | ||||
| } | } | ||||
| } | } | ||||
| return last_bp; | |||||
| GELOGI("Last bp node[%s], type[%s], index[%u], stream id[%ld]", bp_op_desc->GetName().c_str(), | |||||
| bp_op_desc->GetType().c_str(), bp_index, bp_op_desc->GetStreamId()); | |||||
| return SUCCESS; | |||||
| } | } | ||||
| Status TaskGenerator::FindFpOfEnv(const ComputeGraphPtr &graph, const std::string &fp_point_str, | Status TaskGenerator::FindFpOfEnv(const ComputeGraphPtr &graph, const std::string &fp_point_str, | ||||
| @@ -116,7 +116,7 @@ class TaskGenerator { | |||||
| Status AutoFindFpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point) const; | Status AutoFindFpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point) const; | ||||
| Status AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point, | Status AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point, | ||||
| vector<uint32_t> &all_reduce_nodes) const; | vector<uint32_t> &all_reduce_nodes) const; | ||||
| uint32_t FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node) const; | |||||
| Status FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node, uint32_t &bp_index) const; | |||||
| Status FindFpOfEnv(const ComputeGraphPtr &graph, const std::string &fp_point_str, | Status FindFpOfEnv(const ComputeGraphPtr &graph, const std::string &fp_point_str, | ||||
| ProfilingPoint &profiling_point) const; | ProfilingPoint &profiling_point) const; | ||||
| @@ -645,6 +645,7 @@ Status KernelTaskInfo::InitTVMTask(uint16_t offset, const domi::KernelDef &kerne | |||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| args_addr = std::unique_ptr<uint8_t[]>(new (std::nothrow) uint8_t[args_size_]); | args_addr = std::unique_ptr<uint8_t[]>(new (std::nothrow) uint8_t[args_size_]); | ||||
| GE_CHECK_NOTNULL(args_addr); | |||||
| errno_t sec_ret = memcpy_s(args_addr.get(), args_size_, kernel_def.args().data(), args_size_); | errno_t sec_ret = memcpy_s(args_addr.get(), args_size_, kernel_def.args().data(), args_size_); | ||||
| if (sec_ret != EOK) { | if (sec_ret != EOK) { | ||||
| REPORT_CALL_ERROR("E19999", "Call memcpy_s fail, size:%u, ret:0x%X", args_size_, sec_ret); | REPORT_CALL_ERROR("E19999", "Call memcpy_s fail, size:%u, ret:0x%X", args_size_, sec_ret); | ||||
| @@ -1000,6 +1001,7 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k | |||||
| // copy args to new host memory | // copy args to new host memory | ||||
| args_addr = std::unique_ptr<uint8_t[]>(new (std::nothrow) uint8_t[args_size_]); | args_addr = std::unique_ptr<uint8_t[]>(new (std::nothrow) uint8_t[args_size_]); | ||||
| GE_CHECK_NOTNULL(args_addr); | |||||
| GE_PRINT_DYNAMIC_MEMORY(new, "cce task physical memory.", sizeof(uint8_t) * args_size_) | GE_PRINT_DYNAMIC_MEMORY(new, "cce task physical memory.", sizeof(uint8_t) * args_size_) | ||||
| errno_t sec_ret = memcpy_s(args_addr.get(), args_size_, kernel_def.args().data(), args_size_); | errno_t sec_ret = memcpy_s(args_addr.get(), args_size_, kernel_def.args().data(), args_size_); | ||||
| if (sec_ret != EOK) { | if (sec_ret != EOK) { | ||||
| @@ -194,35 +194,6 @@ ge::Status VarResource::GetBroadCastInfo(uint32_t graph_id, const string &var_na | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| ge::Status VarResource::SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name, | |||||
| const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) { | |||||
| GE_CHECK_NOTNULL(base_ptr); | |||||
| GELOGI("SyncVarData2BroadCast graph_id: %u, var_name: %s.", graph_id, var_name.c_str()); | |||||
| VarBroadCastInfo var_broadcast_info = var_broad_cast_info_[graph_id][var_name]; | |||||
| uint8_t *dst_addr = base_ptr + var_broadcast_info.input_offset; | |||||
| return ge::TransVarDataUtils::SyncVarData2BroadCast(var_name, var_tensor_desc, dst_addr, | |||||
| var_broadcast_info.input_size, session_id_); | |||||
| } | |||||
| ge::Status VarResource::SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name, | |||||
| const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) { | |||||
| GELOGI("SyncBroadCastData2Var var_name: %s", var_name.c_str()); | |||||
| VarBroadCastInfo var_broadcast_info = var_broad_cast_info_[graph_id][var_name]; | |||||
| // subgraph base_ptr could be nullptr, task it as base 0 | |||||
| uint8_t *dst_addr = base_ptr + var_broadcast_info.output_offset; | |||||
| return ge::TransVarDataUtils::SyncBroadCastData2Var(dst_addr, var_broadcast_info.output_size, var_name, | |||||
| var_tensor_desc, session_id_); | |||||
| } | |||||
| ge::Status VarResource::SyncVarData(uint32_t graph_id, const std::string &var_name, | |||||
| const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) { | |||||
| return SyncVarData2BroadCast(graph_id, var_name, var_tensor_desc, base_ptr); | |||||
| } | |||||
| bool VarResource::IsVarAddr(const int64_t &offset) { return var_offset_map_.count(offset) > 0; } | bool VarResource::IsVarAddr(const int64_t &offset) { return var_offset_map_.count(offset) > 0; } | ||||
| rtMemType_t VarResource::GetVarMemType(const int64_t &offset) { | rtMemType_t VarResource::GetVarMemType(const int64_t &offset) { | ||||
| @@ -638,16 +609,6 @@ bool VarManager::IsVarExist(const std::string &var_name) { | |||||
| return var_resource_->IsVarExist(var_name); | return var_resource_->IsVarExist(var_name); | ||||
| } | } | ||||
| ge::Status VarManager::SyncVarData(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc, | |||||
| uint8_t *base_ptr) { | |||||
| std::lock_guard<std::recursive_mutex> lock(mutex_); | |||||
| if (var_resource_ == nullptr) { | |||||
| GELOGW("VarManager has not been init."); | |||||
| return ge::INTERNAL_ERROR; | |||||
| } | |||||
| return var_resource_->SyncVarData(graph_id, var_name, var_tensor_desc, base_ptr); | |||||
| } | |||||
| ge::Status VarManager::GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc) { | ge::Status VarManager::GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc) { | ||||
| std::lock_guard<std::recursive_mutex> lock(mutex_); | std::lock_guard<std::recursive_mutex> lock(mutex_); | ||||
| GELOGI("VarManager::GetCurVarDesc var_name = %s.", var_name.c_str()); | GELOGI("VarManager::GetCurVarDesc var_name = %s.", var_name.c_str()); | ||||
| @@ -701,16 +662,6 @@ ge::Status VarManager::RenewCurVarDesc(const std::string &var_name, ge::OpDescPt | |||||
| return var_resource_->RenewCurVarDesc(var_name, std::move(op_desc)); | return var_resource_->RenewCurVarDesc(var_name, std::move(op_desc)); | ||||
| } | } | ||||
| ge::Status VarManager::SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name, | |||||
| const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) { | |||||
| std::lock_guard<std::recursive_mutex> lock(mutex_); | |||||
| if (var_resource_ == nullptr) { | |||||
| GELOGW("VarManager has not been init."); | |||||
| return ge::INTERNAL_ERROR; | |||||
| } | |||||
| return var_resource_->SyncBroadCastData2Var(graph_id, var_name, var_tensor_desc, base_ptr); | |||||
| } | |||||
| bool VarManager::IsVarAddr(const int64_t &offset) { | bool VarManager::IsVarAddr(const int64_t &offset) { | ||||
| std::lock_guard<std::recursive_mutex> lock(mutex_); | std::lock_guard<std::recursive_mutex> lock(mutex_); | ||||
| if (var_resource_ == nullptr) { | if (var_resource_ == nullptr) { | ||||
| @@ -118,15 +118,6 @@ class VarResource { | |||||
| ge::Status GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info); | ge::Status GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info); | ||||
| ge::Status SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name, | |||||
| const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr); | |||||
| ge::Status SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name, | |||||
| const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr); | |||||
| ge::Status SyncVarData(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc, | |||||
| uint8_t *base_ptr); | |||||
| Status SetTransRoad(const std::string &var_name, const VarTransRoad &trans_road) { | Status SetTransRoad(const std::string &var_name, const VarTransRoad &trans_road) { | ||||
| if (var_to_trans_road_.find(var_name) != var_to_trans_road_.end()) { | if (var_to_trans_road_.find(var_name) != var_to_trans_road_.end()) { | ||||
| GELOGW("Var name: %s has already set.", var_name.c_str()); | GELOGW("Var name: %s has already set.", var_name.c_str()); | ||||
| @@ -234,16 +225,10 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { | |||||
| ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr); | ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr); | ||||
| ge::Status SyncVarData(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc, | |||||
| uint8_t *base_ptr); | |||||
| ge::Status SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info); | ge::Status SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info); | ||||
| ge::Status GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info); | ge::Status GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info); | ||||
| ge::Status SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc, | |||||
| uint8_t *base_ptr); | |||||
| ge::Status GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc); | ge::Status GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc); | ||||
| ge::Status RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc); | ge::Status RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc); | ||||
| @@ -415,72 +415,6 @@ Status CopyTensorFromSrcVarNode(const NodePtr &var_src, | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| Status TransVarDataUtils::SyncVarData2BroadCast(const string &var_name, const ge::GeTensorDesc &src_tensor_desc, | |||||
| uint8_t *dst_addr, int64_t dst_addr_size, uint64_t session_id) { | |||||
| GE_CHK_BOOL_RET_STATUS(dst_addr != nullptr, FAILED, "[Check][Param] dst addr is nullptr."); | |||||
| uint8_t *src_host_addr = nullptr; | |||||
| int64_t src_addr_size = 0; | |||||
| GE_MAKE_GUARD_RTMEM(src_host_addr); | |||||
| GE_CHK_STATUS_RET(SyncTensorToHost(var_name, src_tensor_desc, &src_host_addr, src_addr_size, session_id)); | |||||
| GELOGI("src_addr_size: %ld, dst_addr_size: %ld", src_addr_size, dst_addr_size); | |||||
| GE_CHK_BOOL_RET_STATUS(src_addr_size == dst_addr_size, FAILED, | |||||
| "[Check][Param] src_addr_size:%ld not equal to dst_addr_size:%ld", | |||||
| src_addr_size, dst_addr_size); | |||||
| GE_CHK_RT_RET(rtMemcpy(dst_addr, dst_addr_size, src_host_addr, src_addr_size, RT_MEMCPY_HOST_TO_DEVICE)); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status TransVarDataUtils::SyncBroadCastData2Var(uint8_t *src_addr, int64_t src_addr_size, const string &var_name, | |||||
| const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id) { | |||||
| GE_CHK_BOOL_RET_STATUS(src_addr != nullptr, FAILED, "[Check][Param] src addr is nullptr. "); | |||||
| uint8_t *host_addr = nullptr; | |||||
| GE_MAKE_GUARD_RTMEM(host_addr); | |||||
| GE_CHK_RT_RET(rtMallocHost(reinterpret_cast<void **>(&host_addr), src_addr_size)); | |||||
| GE_CHK_RT_RET(rtMemcpy(host_addr, src_addr_size, src_addr, src_addr_size, RT_MEMCPY_DEVICE_TO_HOST)); | |||||
| GE_CHK_STATUS_RET( | |||||
| SyncTensorToDevice(var_name, reinterpret_cast<uint8_t *>(host_addr), src_addr_size, dst_tensor_desc, session_id)); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status TransVarDataUtils::SyncTensorToHost(const string &var_name, const ge::GeTensorDesc &src_tensor_desc, | |||||
| uint8_t **host_addr, int64_t &src_tensor_size, uint64_t session_id) { | |||||
| GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(src_tensor_desc, src_tensor_size), "[Get][Size] from TensorDesc failed"); | |||||
| uint8_t *src_addr = nullptr; | |||||
| GE_CHK_STATUS_RET(VarManager::Instance(session_id)->GetVarAddr(var_name, src_tensor_desc, &src_addr)); | |||||
| uint8_t *mem_addr = | |||||
| src_addr - | |||||
| static_cast<int64_t>(static_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemLogicBase())) + | |||||
| static_cast<int64_t>( | |||||
| reinterpret_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM))); | |||||
| GE_CHK_RT_RET(rtMallocHost(reinterpret_cast<void **>(host_addr), src_tensor_size)); | |||||
| GE_CHK_RT_RET(rtMemcpy(*host_addr, src_tensor_size, mem_addr, src_tensor_size, RT_MEMCPY_DEVICE_TO_HOST)); | |||||
| GELOGI("SyncTensorToHost var_name %s, src_tensor_size %ld", var_name.c_str(), src_tensor_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status TransVarDataUtils::SyncTensorToDevice(const string &var_name, const uint8_t *host_addr, uint32_t addr_size, | |||||
| const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id) { | |||||
| uint8_t *dst_addr = nullptr; | |||||
| GE_CHK_STATUS_RET(VarManager::Instance(session_id)->GetVarAddr(var_name, dst_tensor_desc, &dst_addr)); | |||||
| uint8_t *mem_addr = | |||||
| dst_addr - | |||||
| static_cast<int64_t>(static_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemLogicBase())) + | |||||
| static_cast<int64_t>( | |||||
| reinterpret_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM))); | |||||
| GE_CHK_RT_RET(rtMemcpy(mem_addr, addr_size, host_addr, addr_size, RT_MEMCPY_HOST_TO_DEVICE)); | |||||
| GELOGI("SyncTensorToDevice var_name %s, addr_size %u", var_name.c_str(), addr_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status TransVarDataUtils::TransAllVarData(const vector<NodePtr> &variable_nodes, | Status TransVarDataUtils::TransAllVarData(const vector<NodePtr> &variable_nodes, | ||||
| uint64_t session_id, | uint64_t session_id, | ||||
| rtContext_t context, | rtContext_t context, | ||||
| @@ -29,11 +29,6 @@ | |||||
| namespace ge { | namespace ge { | ||||
| class TransVarDataUtils { | class TransVarDataUtils { | ||||
| public: | public: | ||||
| static ge::Status SyncVarData2BroadCast(const string &var_name, const ge::GeTensorDesc &src_tensor_desc, | |||||
| uint8_t *dst_addr, int64_t dst_addr_size, uint64_t session_id_); | |||||
| static ge::Status SyncBroadCastData2Var(uint8_t *src_addr, int64_t src_addr_size, const string &var_name, | |||||
| const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id_); | |||||
| static ge::Status TransAllVarData(const std::vector<NodePtr> &variable_nodes, | static ge::Status TransAllVarData(const std::vector<NodePtr> &variable_nodes, | ||||
| uint64_t session_id, | uint64_t session_id, | ||||
| rtContext_t context, | rtContext_t context, | ||||
| @@ -41,12 +36,6 @@ class TransVarDataUtils { | |||||
| uint32_t thread_num = 16); | uint32_t thread_num = 16); | ||||
| static ge::Status CopyVarData(const ComputeGraphPtr &compute_graph, uint64_t session_id, uint32_t device_id); | static ge::Status CopyVarData(const ComputeGraphPtr &compute_graph, uint64_t session_id, uint32_t device_id); | ||||
| private: | |||||
| static ge::Status SyncTensorToHost(const string &var_name, const ge::GeTensorDesc &src_tensor_desc, | |||||
| uint8_t **host_addr, int64_t &addr_size, uint64_t session_id_); | |||||
| static ge::Status SyncTensorToDevice(const string &var_name, const uint8_t *host_addr, uint32_t addr_size, | |||||
| const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id_); | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -301,12 +301,26 @@ graphStatus InferValueRangePass::ConstructData(const GeTensorDesc &tensor_desc, | |||||
| GeTensorPtr &output_ptr) { | GeTensorPtr &output_ptr) { | ||||
| std::vector<std::pair<int64_t, int64_t>> value_range; | std::vector<std::pair<int64_t, int64_t>> value_range; | ||||
| (void)tensor_desc.GetValueRange(value_range); | (void)tensor_desc.GetValueRange(value_range); | ||||
| if (static_cast<int64_t>(value_range.size()) != tensor_desc.GetShape().GetShapeSize()) { | |||||
| GELOGW("Value range of input %s is invalid.", tensor_desc.GetName().c_str()); | |||||
| size_t value_range_data_num = value_range.size(); | |||||
| auto tensor_shape = tensor_desc.GetShape(); | |||||
| bool value_range_and_tensor_shape_matched = true; | |||||
| if (tensor_shape.IsScalar()){ | |||||
| // scalar tensor has only one value_range pair | |||||
| if (value_range_data_num != 1) { | |||||
| value_range_and_tensor_shape_matched = false; | |||||
| } | |||||
| } else { | |||||
| // normal tensor, value_range size is equal to tensor shape size. | |||||
| if (static_cast<int64_t>(value_range_data_num) != tensor_shape.GetShapeSize()) { | |||||
| value_range_and_tensor_shape_matched = false; | |||||
| } | |||||
| } | |||||
| if (!value_range_and_tensor_shape_matched) { | |||||
| GELOGW("Input %s value range and tensor shape do not match. Value range size is %zu, tensor shape is %s.", | |||||
| tensor_desc.GetName().c_str(), value_range_data_num, formats::ShapeToString(tensor_shape).c_str()); | |||||
| return GRAPH_PARAM_INVALID; | return GRAPH_PARAM_INVALID; | ||||
| } | } | ||||
| size_t value_range_data_num = value_range.size(); | |||||
| unique_ptr<T[]> buf(new (std::nothrow) T[value_range_data_num]()); | unique_ptr<T[]> buf(new (std::nothrow) T[value_range_data_num]()); | ||||
| if (buf == nullptr) { | if (buf == nullptr) { | ||||
| REPORT_INNER_ERROR("E19999", "New buf failed"); | REPORT_INNER_ERROR("E19999", "New buf failed"); | ||||
| @@ -494,10 +508,16 @@ void InferValueRangePass::ConstructValueRange(const GeTensorPtr &left_tensor, co | |||||
| GELOGI("Output tensor of cpu kernel does not have data, no way to set value range."); | GELOGI("Output tensor of cpu kernel does not have data, no way to set value range."); | ||||
| return; | return; | ||||
| } | } | ||||
| for (auto j = 0; j < left_tensor->GetTensorDesc().GetShape().GetShapeSize(); ++j) { | |||||
| auto left_tensor_shape = left_tensor->GetTensorDesc().GetShape(); | |||||
| for (auto j = 0; j < left_tensor_shape.GetShapeSize(); ++j) { | |||||
| auto left = static_cast<int64_t>(*(x + j)); | auto left = static_cast<int64_t>(*(x + j)); | ||||
| auto right = static_cast<int64_t>(*(y + j)); | auto right = static_cast<int64_t>(*(y + j)); | ||||
| value_range.emplace_back(std::make_pair(left, right)); | |||||
| value_range.emplace_back(left, right); | |||||
| } | |||||
| if (left_tensor_shape.IsScalar()) { | |||||
| GELOGD("When inferring value range, output tensors of cpu kernel are scalar tensors."); | |||||
| value_range.emplace_back(static_cast<int64_t>(*x), static_cast<int64_t>(*y)); | |||||
| } | } | ||||
| } | } | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -568,6 +568,7 @@ Status InsertNewOpUtil::GetDataRelatedNode(NodePtr &node, std::map<NodePtr, std: | |||||
| } | } | ||||
| std::unique_ptr<domi::AippOpParams> aipp_params(new (std::nothrow) domi::AippOpParams()); | std::unique_ptr<domi::AippOpParams> aipp_params(new (std::nothrow) domi::AippOpParams()); | ||||
| GE_CHECK_NOTNULL(aipp_params); | |||||
| ge::GeAttrValue::NAMED_ATTRS aipp_attr; | ge::GeAttrValue::NAMED_ATTRS aipp_attr; | ||||
| GE_CHK_BOOL_RET_STATUS(AttrUtils::GetNamedAttrs(data_op, ATTR_NAME_AIPP, aipp_attr), ACL_ERROR_GE_AIPP_NOT_EXIST, | GE_CHK_BOOL_RET_STATUS(AttrUtils::GetNamedAttrs(data_op, ATTR_NAME_AIPP, aipp_attr), ACL_ERROR_GE_AIPP_NOT_EXIST, | ||||
| "[Get][Attr] %s from op:%s failed", ATTR_NAME_AIPP.c_str(), data_op->GetName().c_str()); | "[Get][Attr] %s from op:%s failed", ATTR_NAME_AIPP.c_str(), data_op->GetName().c_str()); | ||||
| @@ -1206,7 +1206,7 @@ Status MultiBatchGraphCopyer::CheckCopyResult(const std::vector<NodePtr> &start_ | |||||
| auto dims = NodeUtils::GetOutputDesc(*node, kDataOutIndex).GetShape().GetDims(); | auto dims = NodeUtils::GetOutputDesc(*node, kDataOutIndex).GetShape().GetDims(); | ||||
| if (!IsAllDimsPositive(dims)) { | if (!IsAllDimsPositive(dims)) { | ||||
| REPORT_CALL_ERROR("E19999", "Failed to copy multi batch graph, the node %s still has unknown shape %s", | REPORT_CALL_ERROR("E19999", "Failed to copy multi batch graph, the node %s still has unknown shape %s", | ||||
| node->GetName().c_str(), formats::ShapeToString(dims).c_str()); | |||||
| node->GetName().c_str(), formats::ShapeToString(dims).c_str()); | |||||
| GELOGE(INTERNAL_ERROR, "[Check][Param] Failed to copy multi batch graph, the node %s still has unknown shape %s", | GELOGE(INTERNAL_ERROR, "[Check][Param] Failed to copy multi batch graph, the node %s still has unknown shape %s", | ||||
| node->GetName().c_str(), formats::ShapeToString(dims).c_str()); | node->GetName().c_str(), formats::ShapeToString(dims).c_str()); | ||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| @@ -295,13 +295,15 @@ Status HybridModelAsyncExecutor::PrepareInputs(const InputData ¤t_data, Hy | |||||
| } | } | ||||
| } | } | ||||
| tensor_desc->SetShape(shape); | tensor_desc->SetShape(shape); | ||||
| args.input_desc[input_index] = tensor_desc; | |||||
| GELOGD("Update shape of input[%zu] to [%s]", input_index, tensor_desc->MutableShape().ToString().c_str()); | |||||
| GELOGD("Update shape[%s] of input[%zu] to [%s]", | |||||
| shape.ToString().c_str(), input_index, tensor_desc->MutableShape().ToString().c_str()); | |||||
| GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorMemorySizeInBytes(*tensor_desc, tensor_size), | GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorMemorySizeInBytes(*tensor_desc, tensor_size), | ||||
| "[Invoke][GetTensorMemorySizeInBytes]Failed to calc tensor size," | "[Invoke][GetTensorMemorySizeInBytes]Failed to calc tensor size," | ||||
| "index = %zu, shape = [%s], model_id = %u.", | "index = %zu, shape = [%s], model_id = %u.", | ||||
| input_index, tensor_desc->GetShape().ToString().c_str(), model_id_); | input_index, tensor_desc->GetShape().ToString().c_str(), model_id_); | ||||
| GELOGD("Input tensor[%zu] size = %zu", input_index, tensor_size); | |||||
| GELOGD("Input tensor[%zu] size = %ld", input_index, tensor_size); | |||||
| TensorUtils::SetSize(*tensor_desc, tensor_size); | |||||
| args.input_desc[input_index] = tensor_desc; | |||||
| } | } | ||||
| GE_CHECK_GE(tensor_size, 0); | GE_CHECK_GE(tensor_size, 0); | ||||
| @@ -33,9 +33,6 @@ HybridModelExecutor::HybridModelExecutor(HybridModel *model, uint32_t device_id, | |||||
| } | } | ||||
| HybridModelExecutor::~HybridModelExecutor() { | HybridModelExecutor::~HybridModelExecutor() { | ||||
| if (context_.rt_gen_context != nullptr) { | |||||
| (void) rtCtxDestroy(context_.rt_gen_context); | |||||
| } | |||||
| } | } | ||||
| Status HybridModelExecutor::Init() { | Status HybridModelExecutor::Init() { | ||||
| @@ -139,7 +136,6 @@ Status HybridModelExecutor::Cleanup() { | |||||
| Status HybridModelExecutor::InitExecutionContext() { | Status HybridModelExecutor::InitExecutionContext() { | ||||
| GE_CHK_RT_RET(rtCtxGetCurrent(&context_.rt_context)); | GE_CHK_RT_RET(rtCtxGetCurrent(&context_.rt_context)); | ||||
| GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0)); | |||||
| GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); | GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); | ||||
| context_.global_step = model_->GetGlobalStep(); | context_.global_step = model_->GetGlobalStep(); | ||||
| @@ -191,7 +191,6 @@ HybridModelPipelineExecutor::HybridModelPipelineExecutor(HybridModel *model, uin | |||||
| } | } | ||||
| Status StageExecutor::InitExecutionContext() { | Status StageExecutor::InitExecutionContext() { | ||||
| GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0)); | |||||
| GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); | GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); | ||||
| context_.model = model_; | context_.model = model_; | ||||
| @@ -21,10 +21,17 @@ | |||||
| namespace ge { | namespace ge { | ||||
| namespace hybrid { | namespace hybrid { | ||||
| Status TaskCompileEngine::Compile(NodeState &node_state, GraphExecutionContext *context) { | Status TaskCompileEngine::Compile(NodeState &node_state, GraphExecutionContext *context) { | ||||
| const auto &node_item = *node_state.GetNodeItem(); | |||||
| GE_CHECK_NOTNULL(context); | GE_CHECK_NOTNULL(context); | ||||
| rtContext_t rt_gen_context = nullptr; | |||||
| GE_CHK_RT_RET(rtCtxCreate(&rt_gen_context, RT_CTX_GEN_MODE, 0)); | |||||
| std::function<void()> callback = [&]() { | |||||
| (void) rtCtxDestroy(rt_gen_context); | |||||
| GE_CHK_RT(rtCtxSetCurrent(context->rt_context)); | |||||
| }; | |||||
| GE_MAKE_GUARD(rt_gen_context, callback); | |||||
| const auto &node_item = *node_state.GetNodeItem(); | |||||
| RECORD_COMPILE_EVENT(context, node_item.NodeName().c_str(), "[Compile] Start"); | RECORD_COMPILE_EVENT(context, node_item.NodeName().c_str(), "[Compile] Start"); | ||||
| GE_CHK_RT_RET(rtCtxSetCurrent(context->rt_gen_context)); | |||||
| if (context->ge_context != nullptr) { | if (context->ge_context != nullptr) { | ||||
| GetThreadLocalContext() = *context->ge_context; | GetThreadLocalContext() = *context->ge_context; | ||||
| @@ -1044,6 +1044,7 @@ Status HybridModelBuilder::InitConstantOps() { | |||||
| } else { | } else { | ||||
| var_tensor.reset(new(std::nothrow)TensorValue(nullptr, 0)); | var_tensor.reset(new(std::nothrow)TensorValue(nullptr, 0)); | ||||
| } | } | ||||
| GE_CHECK_NOTNULL(var_tensor); | |||||
| } else { | } else { | ||||
| GE_CHK_STATUS_RET_NOLOG(VarNodeToTensor(var_node, var_tensor)); | GE_CHK_STATUS_RET_NOLOG(VarNodeToTensor(var_node, var_tensor)); | ||||
| GELOGD("Init const op tensor. name = %s, size = %ld", var_name.c_str(), var_tensor->GetSize()); | GELOGD("Init const op tensor. name = %s, size = %ld", var_name.c_str(), var_tensor->GetSize()); | ||||
| @@ -24,6 +24,8 @@ | |||||
| namespace ge { | namespace ge { | ||||
| namespace hybrid { | namespace hybrid { | ||||
| namespace { | namespace { | ||||
| const uint8_t kMaxTransCount = 3; | |||||
| const uint32_t kTransOpIoSize = 1; | |||||
| const char *const kAttrNameOriginalFusionGraph = "_original_fusion_graph"; | const char *const kAttrNameOriginalFusionGraph = "_original_fusion_graph"; | ||||
| const char *const kNodeTypeRetVal = "_RetVal"; | const char *const kNodeTypeRetVal = "_RetVal"; | ||||
| const std::set<std::string> kControlOpTypes{ | const std::set<std::string> kControlOpTypes{ | ||||
| @@ -39,6 +41,25 @@ const std::set<std::string> kMergeOpTypes{ | |||||
| MERGE, REFMERGE, STREAMMERGE | MERGE, REFMERGE, STREAMMERGE | ||||
| }; | }; | ||||
| bool IsEnterFeedNode(NodePtr node) { | |||||
| // For: Enter -> node | |||||
| // For: Enter -> Cast -> node | |||||
| // For: Enter -> TransData -> Cast -> node | |||||
| for (uint8_t i = 0; i < kMaxTransCount; ++i) { | |||||
| if (kEnterOpTypes.count(NodeUtils::GetNodeType(node)) > 0) { | |||||
| GELOGD("Node[%u] is Enter feed node.", node->GetName().c_str()); | |||||
| return true; | |||||
| } | |||||
| const auto all_nodes = node->GetInDataNodes(); | |||||
| if (all_nodes.size() != kTransOpIoSize || node->GetAllInDataAnchorsSize() != kTransOpIoSize) { | |||||
| return false; | |||||
| } | |||||
| node = all_nodes.at(0); | |||||
| } | |||||
| return false; | |||||
| } | |||||
| Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) { | Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) { | ||||
| uint32_t parent_index = 0; | uint32_t parent_index = 0; | ||||
| if (!AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { | if (!AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { | ||||
| @@ -399,7 +420,7 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) { | |||||
| data_anchors.emplace(anchor_index); | data_anchors.emplace(anchor_index); | ||||
| } | } | ||||
| // If Enter feed Not Merge, take as root Node. | // If Enter feed Not Merge, take as root Node. | ||||
| if (IsEnterOp() && (node_item->node_type != STREAMMERGE)) { | |||||
| if (IsEnterFeedNode(node) && (node_item->node_type != STREAMMERGE)) { | |||||
| auto &data_anchors = node_item->enter_data_[this]; | auto &data_anchors = node_item->enter_data_[this]; | ||||
| data_anchors.emplace(anchor_index); | data_anchors.emplace(anchor_index); | ||||
| } | } | ||||
| @@ -419,7 +440,7 @@ void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) { | |||||
| node_item->root_ctrl_.emplace(this); | node_item->root_ctrl_.emplace(this); | ||||
| } | } | ||||
| // If Enter feed control signal, take as root Node. | // If Enter feed control signal, take as root Node. | ||||
| if (IsEnterOp() && (node_item->node_type != STREAMMERGE && node_item->node_type != STREAMACTIVE)) { | |||||
| if (IsEnterFeedNode(node) && (node_item->node_type != STREAMMERGE && node_item->node_type != STREAMACTIVE)) { | |||||
| node_item->enter_ctrl_.emplace(this); | node_item->enter_ctrl_.emplace(this); | ||||
| } | } | ||||
| GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | ||||
| @@ -50,6 +50,8 @@ const std::set<std::string> kBufferOptimizeSupportOption = {"l1_optimize", "l2_o | |||||
| const char *const kBufferOptimizeSupport = "only support l2_optimize, off_optimize"; | const char *const kBufferOptimizeSupport = "only support l2_optimize, off_optimize"; | ||||
| const char *const IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT = "high_performance"; | const char *const IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT = "high_performance"; | ||||
| const char *const IR_OPTION_OP_SELECT_IMPLMODE_PRECISON = "high_precision"; | const char *const IR_OPTION_OP_SELECT_IMPLMODE_PRECISON = "high_precision"; | ||||
| const char *const IR_OPTION_OP_SELECT_IMPLMODE_HIGH_PRECISION_FOR_ALL = "high_precision_for_all"; | |||||
| const char *const IR_OPTION_OP_SELECT_IMPLMODE_HIGH_PERFORMANCE_FOR_ALL = "high_performance_for_all"; | |||||
| const char *const kInputShapeSample1 = "\"input_name1:n1,c1,h1,w1\""; | const char *const kInputShapeSample1 = "\"input_name1:n1,c1,h1,w1\""; | ||||
| const char *const kInputShapeSample2 = "\"input_name1:1,3,224,224\""; | const char *const kInputShapeSample2 = "\"input_name1:1,3,224,224\""; | ||||
| const char *const kSplitError1 = "size not equal to 2 split by \":\""; | const char *const kSplitError1 = "size not equal to 2 split by \":\""; | ||||
| @@ -57,7 +59,8 @@ const char *const kEmptyError = "can not be empty"; | |||||
| const char *const kFloatNumError = "exist float number"; | const char *const kFloatNumError = "exist float number"; | ||||
| const char *const kDigitError = "is not digit"; | const char *const kDigitError = "is not digit"; | ||||
| const char *const kCompressWeightError = "it must be appointed when appoint parameter[--optypelist_for_implmode]"; | const char *const kCompressWeightError = "it must be appointed when appoint parameter[--optypelist_for_implmode]"; | ||||
| const char *const kSelectImplmodeError = "only support high_performance, high_precision"; | |||||
| const char *const kSelectImplmodeError = "only support high_performance, high_precision, " | |||||
| "high_precision_for_all, high_performance_for_all"; | |||||
| const char *const kDynamicBatchSizeError = "It can only contains digit, \",\", \" \""; | const char *const kDynamicBatchSizeError = "It can only contains digit, \",\", \" \""; | ||||
| const char *const kDynamicImageSizeError = "It can only contains digit, \",\", \" \" and \";\""; | const char *const kDynamicImageSizeError = "It can only contains digit, \",\", \" \" and \";\""; | ||||
| const char *const kKeepDtypeError = "file not found"; | const char *const kKeepDtypeError = "file not found"; | ||||
| @@ -782,7 +785,9 @@ Status CheckImplmodeParamValid(const std::string &optypelist_for_implmode, std:: | |||||
| op_select_implmode = IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT; | op_select_implmode = IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT; | ||||
| } else { | } else { | ||||
| if (op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT && | if (op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT && | ||||
| op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_PRECISON) { | |||||
| op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_PRECISON && | |||||
| op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_HIGH_PRECISION_FOR_ALL && | |||||
| op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_HIGH_PERFORMANCE_FOR_ALL) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, | ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, | ||||
| {"--op_select_implmode", op_select_implmode.c_str(), | {"--op_select_implmode", op_select_implmode.c_str(), | ||||
| kSelectImplmodeError}); | kSelectImplmodeError}); | ||||
| @@ -143,7 +143,8 @@ DEFINE_string(output_type, "", | |||||
| DEFINE_string(op_select_implmode, "", | DEFINE_string(op_select_implmode, "", | ||||
| "Optional; op select implmode! " | "Optional; op select implmode! " | ||||
| "Support high_precision, high_performance."); | |||||
| "Support high_precision, high_performance, " | |||||
| "high_precision_for_all, high_performance_for_all."); | |||||
| DEFINE_string(optypelist_for_implmode, "", | DEFINE_string(optypelist_for_implmode, "", | ||||
| "Optional; Nodes need use implmode selected in op_select_implmode " | "Optional; Nodes need use implmode selected in op_select_implmode " | ||||
| @@ -311,8 +312,8 @@ class GFlagUtils { | |||||
| "scenarios by using a configuration file.\n" | "scenarios by using a configuration file.\n" | ||||
| " --auto_tune_mode Set tune mode. E.g.: \"GA,RL\", support configure multiple, spit by ,\n" | " --auto_tune_mode Set tune mode. E.g.: \"GA,RL\", support configure multiple, spit by ,\n" | ||||
| " --op_bank_path Set the path of the custom repository generated after operator tuning with Auto Tune.\n" | " --op_bank_path Set the path of the custom repository generated after operator tuning with Auto Tune.\n" | ||||
| " --op_select_implmode Set op select implmode. Support high_precision, high_performance. " | |||||
| "default: high_performance\n" | |||||
| " --op_select_implmode Set op select implmode. Support high_precision, high_performance, " | |||||
| "high_precision_for_all, high_performance_for_all. default: high_performance\n" | |||||
| " --optypelist_for_implmode Appoint which op to select implmode, cooperated with op_select_implmode.\n" | " --optypelist_for_implmode Appoint which op to select implmode, cooperated with op_select_implmode.\n" | ||||
| " Separate multiple nodes with commas (,). Use double quotation marks (\") " | " Separate multiple nodes with commas (,). Use double quotation marks (\") " | ||||
| "to enclose each argument. E.g.: \"node_name1,node_name2\"\n" | "to enclose each argument. E.g.: \"node_name1,node_name2\"\n" | ||||
| @@ -95,35 +95,6 @@ Status CheckInferDepend(GeModelPtr &ge_model, bool &is_infer_depend, bool &is_ho | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status NeedHybridModel(GeModelPtr &ge_model, bool &flag) { | |||||
| bool is_infer_depend = false; | |||||
| bool is_host_mem = false; | |||||
| GE_CHK_STATUS_RET(CheckInferDepend(ge_model, is_infer_depend, is_host_mem), "[Check][InferDepend] failed."); | |||||
| bool need_d2h_cpy = is_infer_depend && !is_host_mem; | |||||
| auto tasks = ge_model->GetModelTaskDefPtr()->task(); | |||||
| int32_t kernel_task_num = 0; | |||||
| for (int i = 0; i < tasks.size(); ++i) { | |||||
| auto task_type = static_cast<rtModelTaskType_t>(tasks[i].type()); | |||||
| if (task_type == RT_MODEL_TASK_KERNEL || task_type == RT_MODEL_TASK_ALL_KERNEL) { | |||||
| const auto &context = task_type == RT_MODEL_TASK_KERNEL ? tasks[i].kernel().context() : | |||||
| tasks[i].kernel_with_handle().context(); | |||||
| auto kernel_type = static_cast<ccKernelType>(context.kernel_type()); | |||||
| if (kernel_type == ccKernelType::TE) { | |||||
| if (need_d2h_cpy) { | |||||
| flag = true; | |||||
| return SUCCESS; | |||||
| } | |||||
| kernel_task_num++; | |||||
| if (kernel_task_num > 1) { | |||||
| flag = true; | |||||
| return SUCCESS; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| SingleOpModel::SingleOpModel(const std::string &model_name, const void *model_data, uint32_t model_size) | SingleOpModel::SingleOpModel(const std::string &model_name, const void *model_data, uint32_t model_size) | ||||
| @@ -558,14 +529,15 @@ Status SingleOpModel::BuildOp(StreamResource &resource, SingleOp &single_op) { | |||||
| return BuildTaskList(&resource, single_op); | return BuildTaskList(&resource, single_op); | ||||
| } | } | ||||
| Status SingleOpModel::BuildModelTaskKernel(StreamResource *stream_resource, const TaskDef &task_def, | |||||
| DynamicSingleOp &single_op) { | |||||
| auto task_type = static_cast<rtModelTaskType_t>(task_def.type()); | |||||
| const auto &context = task_type == RT_MODEL_TASK_KERNEL ? task_def.kernel().context() : | |||||
| task_def.kernel_with_handle().context(); | |||||
| Status SingleOpModel::BuildTaskListForDynamicOp(StreamResource *stream_resource, DynamicSingleOp &single_op) { | |||||
| auto ge_model = model_helper_.GetGeModel(); | |||||
| GE_CHECK_NOTNULL(ge_model); | |||||
| auto kernel_type = static_cast<ccKernelType>(context.kernel_type()); | |||||
| if (kernel_type == ccKernelType::TE) { | |||||
| auto compute_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph()); | |||||
| GE_CHECK_NOTNULL(compute_graph); | |||||
| single_op.compute_graph_ = compute_graph; | |||||
| if (tbe_tasks_.size() > 0) { | |||||
| const auto &task_def = tbe_tasks_[0]; | |||||
| GELOGD("Building TBE task."); | GELOGD("Building TBE task."); | ||||
| TbeOpTask *tbe_task = nullptr; | TbeOpTask *tbe_task = nullptr; | ||||
| GE_CHK_STATUS_RET_NOLOG(BuildKernelTask(task_def, &tbe_task)); | GE_CHK_STATUS_RET_NOLOG(BuildKernelTask(task_def, &tbe_task)); | ||||
| @@ -575,71 +547,81 @@ Status SingleOpModel::BuildModelTaskKernel(StreamResource *stream_resource, cons | |||||
| tbe_task->stream_resource_ = stream_resource; | tbe_task->stream_resource_ = stream_resource; | ||||
| } | } | ||||
| single_op.op_task_.reset(tbe_task); | single_op.op_task_.reset(tbe_task); | ||||
| } else if (kernel_type == ccKernelType::AI_CPU || kernel_type == ccKernelType::CUST_AI_CPU) { | |||||
| GELOGD("Building AICPU_CC task"); | |||||
| OpTask *task = nullptr; | |||||
| uint64_t dynamic_singleop_kernel_id = aicpu_kernel_id++; | |||||
| GELOGI("Build dynamic singleOp CCTask, kernel_id = %lu", dynamic_singleop_kernel_id); | |||||
| GE_CHK_STATUS_RET_NOLOG(BuildCpuKernelTask(task_def.kernel(), &task, dynamic_singleop_kernel_id)); | |||||
| task->SetModelArgs(model_name_, model_id_); | |||||
| single_op.op_task_.reset(task); | |||||
| } else { | |||||
| GELOGE(ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID, | |||||
| "[Check][Param:TaskDef]Only TBE, AI_CPU, CUST_AI_CPU kernel are supported, but got %u", | |||||
| context.kernel_type()); | |||||
| REPORT_INNER_ERROR("E19999", | |||||
| "BuildModelTaskKernel fail for got:%u not supported, Only TBE, AI_CPU, CUST_AI_CPU kernel are supported.", | |||||
| context.kernel_type()); | |||||
| return ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status SingleOpModel::BuildTaskListForDynamicOp(StreamResource *stream_resource, DynamicSingleOp &single_op) { | |||||
| auto ge_model = model_helper_.GetGeModel(); | |||||
| GE_CHECK_NOTNULL(ge_model); | |||||
| auto compute_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph()); | |||||
| GE_CHECK_NOTNULL(compute_graph); | |||||
| single_op.compute_graph_ = compute_graph; | |||||
| auto tasks = ge_model->GetModelTaskDefPtr()->task(); | |||||
| for (int i = 0; i < tasks.size(); ++i) { | |||||
| const TaskDef &task_def = tasks[i]; | |||||
| GELOGI("[%s] Task[%d], type = [%u], DebugString = [%s]", model_name_.c_str(), i, task_def.type(), | |||||
| task_def.DebugString().c_str()); | |||||
| } else if (aicpu_tasks_.size() > 0) { | |||||
| const auto &task_def = aicpu_tasks_[0]; | |||||
| auto task_type = static_cast<rtModelTaskType_t>(task_def.type()); | auto task_type = static_cast<rtModelTaskType_t>(task_def.type()); | ||||
| if (task_type == RT_MODEL_TASK_KERNEL || task_type == RT_MODEL_TASK_ALL_KERNEL) { | |||||
| if (single_op.op_task_ != nullptr) { | |||||
| GELOGE(ACL_ERROR_GE_OP_TASK_TYPE_INVALID, "[Check][TaskType]Do not support dynamic op with multiple tasks."); | |||||
| REPORT_INNER_ERROR("E19999", | |||||
| "BuildTaskListForDynamicOp fail for Do not support dynamic op with multiple tasks."); | |||||
| return ACL_ERROR_GE_OP_TASK_TYPE_INVALID; | |||||
| } | |||||
| GE_CHK_STATUS_RET_NOLOG(BuildModelTaskKernel(stream_resource, task_def, single_op)); | |||||
| if (task_type == RT_MODEL_TASK_KERNEL) { | |||||
| GELOGD("Building AICPU_CC task"); | |||||
| OpTask *task = nullptr; | |||||
| uint64_t dynamic_singleop_kernel_id = aicpu_kernel_id++; | |||||
| GELOGI("Build dynamic singleOp CCTask, kernel_id = %lu", dynamic_singleop_kernel_id); | |||||
| GE_CHK_STATUS_RET_NOLOG(BuildCpuKernelTask(task_def.kernel(), &task, dynamic_singleop_kernel_id)); | |||||
| task->SetModelArgs(model_name_, model_id_); | |||||
| single_op.op_task_.reset(task); | |||||
| } else if (task_type == RT_MODEL_TASK_KERNEL_EX) { | } else if (task_type == RT_MODEL_TASK_KERNEL_EX) { | ||||
| if (single_op.op_task_ != nullptr) { | |||||
| GELOGE(ACL_ERROR_GE_OP_TASK_TYPE_INVALID, "[Check][TaskType]Do not support dynamic op with multiple tasks."); | |||||
| REPORT_INNER_ERROR("E19999", | |||||
| "BuildTaskListForDynamicOp fail for Do not support dynamic op with multiple tasks."); | |||||
| return ACL_ERROR_GE_OP_TASK_TYPE_INVALID; | |||||
| } | |||||
| GELOGD("Building AICPU_TF task"); | GELOGD("Building AICPU_TF task"); | ||||
| AiCpuTask *aicpu_task = nullptr; | AiCpuTask *aicpu_task = nullptr; | ||||
| uint64_t dynamic_singleop_kernel_id = aicpu_kernel_id++; | uint64_t dynamic_singleop_kernel_id = aicpu_kernel_id++; | ||||
| GELOGI("Build dynamic singleOp TfTask, kernel_id = %lu", dynamic_singleop_kernel_id); | GELOGI("Build dynamic singleOp TfTask, kernel_id = %lu", dynamic_singleop_kernel_id); | ||||
| GE_CHK_STATUS_RET_NOLOG(BuildKernelExTask(task_def.kernel_ex(), &aicpu_task, dynamic_singleop_kernel_id)); | GE_CHK_STATUS_RET_NOLOG(BuildKernelExTask(task_def.kernel_ex(), &aicpu_task, dynamic_singleop_kernel_id)); | ||||
| if (aicpu_task->GetUnknownType() == DEPEND_COMPUTE) { | if (aicpu_task->GetUnknownType() == DEPEND_COMPUTE) { | ||||
| if (i >= tasks.size() - 1) { | |||||
| if (aicpu_tasks_.size() < 2) { | |||||
| GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[Check][Task]The copy task of the fourth operator was not found."); | GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[Check][Task]The copy task of the fourth operator was not found."); | ||||
| REPORT_INNER_ERROR("E19999", "The copy task of the fourth operator was not found."); | REPORT_INNER_ERROR("E19999", "The copy task of the fourth operator was not found."); | ||||
| return ACL_ERROR_GE_PARAM_INVALID; | return ACL_ERROR_GE_PARAM_INVALID; | ||||
| } | } | ||||
| ++i; | |||||
| const TaskDef ©_task_def = tasks[i]; | |||||
| const TaskDef ©_task_def = aicpu_tasks_[1]; | |||||
| GE_CHK_STATUS_RET_NOLOG(aicpu_task->SetMemCopyTask(copy_task_def.kernel_ex())); | GE_CHK_STATUS_RET_NOLOG(aicpu_task->SetMemCopyTask(copy_task_def.kernel_ex())); | ||||
| } | } | ||||
| aicpu_task->SetModelArgs(model_name_, model_id_); | aicpu_task->SetModelArgs(model_name_, model_id_); | ||||
| single_op.op_task_.reset(aicpu_task); | single_op.op_task_.reset(aicpu_task); | ||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status SingleOpModel::NeedHybridModel(GeModelPtr &ge_model, bool &need_hybrid_model) { | |||||
| bool is_infer_depend = false; | |||||
| bool is_host_mem = false; | |||||
| GE_CHK_STATUS_RET(CheckInferDepend(ge_model, is_infer_depend, is_host_mem), "[Check][InferDepend] failed."); | |||||
| bool need_d2h_cpy = is_infer_depend && !is_host_mem; | |||||
| bool aicpu_multi_task = tbe_tasks_.size() >= 1 && aicpu_tasks_.size() >= 1; | |||||
| bool aicore_multi_task = tbe_tasks_.size() > 1; | |||||
| need_hybrid_model = need_d2h_cpy || aicore_multi_task || aicpu_multi_task; | |||||
| return SUCCESS; | |||||
| } | |||||
| Status SingleOpModel::ParseTasks() { | |||||
| auto ge_model = model_helper_.GetGeModel(); | |||||
| GE_CHECK_NOTNULL(ge_model); | |||||
| auto tasks = ge_model->GetModelTaskDefPtr()->task(); | |||||
| for (int i = 0; i < tasks.size(); ++i) { | |||||
| TaskDef &task_def = tasks[i]; | |||||
| GELOGI("[%s] Task[%d], type = [%u], DebugString = [%s]", model_name_.c_str(), i, task_def.type(), | |||||
| task_def.DebugString().c_str()); | |||||
| auto task_type = static_cast<rtModelTaskType_t>(task_def.type()); | |||||
| if (task_type == RT_MODEL_TASK_KERNEL) { | |||||
| const auto &kernel_def = task_def.kernel(); | |||||
| const auto &context = kernel_def.context(); | |||||
| auto kernel_type = static_cast<ccKernelType>(context.kernel_type()); | |||||
| if (kernel_type == ccKernelType::TE) { | |||||
| tbe_tasks_.emplace_back(task_def); | |||||
| } else if (kernel_type == ccKernelType::AI_CPU || kernel_type == ccKernelType::CUST_AI_CPU) { | |||||
| aicpu_tasks_.emplace_back(task_def); | |||||
| } else { | |||||
| GELOGE(ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID, | |||||
| "[Check][Param:TaskDef]Only TBE, AI_CPU, CUST_AI_CPU kernel are supported, but got %u", | |||||
| context.kernel_type()); | |||||
| REPORT_INNER_ERROR("E19999", | |||||
| "BuildModelTaskKernel fail for got:%u not supported, Only TBE, AI_CPU, CUST_AI_CPU kernel are supported.", | |||||
| context.kernel_type()); | |||||
| return ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID; | |||||
| } | |||||
| } else if (task_type == RT_MODEL_TASK_ALL_KERNEL) { | |||||
| tbe_tasks_.emplace_back(task_def); | |||||
| } else if (task_type == RT_MODEL_TASK_KERNEL_EX) { | |||||
| aicpu_tasks_.emplace_back(task_def); | |||||
| } else { | } else { | ||||
| // skip | // skip | ||||
| GELOGD("Skip task type: %d", static_cast<int>(task_type)); | GELOGD("Skip task type: %d", static_cast<int>(task_type)); | ||||
| @@ -654,6 +636,7 @@ Status SingleOpModel::BuildDynamicOp(StreamResource &resource, DynamicSingleOp & | |||||
| GE_CHK_STATUS_RET_NOLOG(InitModelMem(resource)); | GE_CHK_STATUS_RET_NOLOG(InitModelMem(resource)); | ||||
| model_params_.memory_size = UINT64_MAX; | model_params_.memory_size = UINT64_MAX; | ||||
| model_params_.graph_is_dynamic = true; | model_params_.graph_is_dynamic = true; | ||||
| GE_CHK_STATUS_RET(ParseTasks(), "[Parse][Tasks] failed."); | |||||
| auto ge_model = model_helper_.GetGeModel(); | auto ge_model = model_helper_.GetGeModel(); | ||||
| GE_CHECK_NOTNULL(ge_model); | GE_CHECK_NOTNULL(ge_model); | ||||
| @@ -71,13 +71,16 @@ class SingleOpModel { | |||||
| Status BuildKernelTask(const domi::TaskDef &task_def, TbeOpTask **task); | Status BuildKernelTask(const domi::TaskDef &task_def, TbeOpTask **task); | ||||
| Status BuildKernelExTask(const domi::KernelExDef &kernel_def, AiCpuTask **task, uint64_t kernel_id); | Status BuildKernelExTask(const domi::KernelExDef &kernel_def, AiCpuTask **task, uint64_t kernel_id); | ||||
| Status BuildCpuKernelTask(const domi::KernelDef &kernel_def, OpTask **task, uint64_t kernel_id); | Status BuildCpuKernelTask(const domi::KernelDef &kernel_def, OpTask **task, uint64_t kernel_id); | ||||
| Status BuildModelTaskKernel(StreamResource *stream_resource, const domi::TaskDef &task_def, | |||||
| DynamicSingleOp &single_op); | |||||
| static void ParseOpModelParams(ModelHelper &model_helper, SingleOpModelParam ¶m); | static void ParseOpModelParams(ModelHelper &model_helper, SingleOpModelParam ¶m); | ||||
| void ParseArgTable(OpTask *task, SingleOp &op); | void ParseArgTable(OpTask *task, SingleOp &op); | ||||
| Status InitHybridModelExecutor(const StreamResource &resource, const GeModelPtr &ge_model, SingleOp &single_op); | Status InitHybridModelExecutor(const StreamResource &resource, const GeModelPtr &ge_model, SingleOp &single_op); | ||||
| Status SetHostMemTensor(DynamicSingleOp &single_op); | Status SetHostMemTensor(DynamicSingleOp &single_op); | ||||
| Status NeedHybridModel(GeModelPtr &ge_model, bool &flag); | |||||
| Status ParseTasks(); | |||||
| std::vector<domi::TaskDef> tbe_tasks_; | |||||
| std::vector<domi::TaskDef> aicpu_tasks_; | |||||
| std::string model_name_; | std::string model_name_; | ||||
| uint32_t model_id_ = 0; | uint32_t model_id_ = 0; | ||||
| @@ -33,6 +33,10 @@ | |||||
| #include "register/op_tiling.h" | #include "register/op_tiling.h" | ||||
| namespace ge { | namespace ge { | ||||
| namespace { | |||||
| const int kAddressNum = 2; | |||||
| } // namespace | |||||
| class StreamResource; | class StreamResource; | ||||
| struct SingleOpModelParam; | struct SingleOpModelParam; | ||||
| class OpTask { | class OpTask { | ||||
| @@ -264,7 +268,7 @@ class MemcpyAsyncTask : public OpTask { | |||||
| friend class SingleOpModel; | friend class SingleOpModel; | ||||
| friend class RtsKernelTaskBuilder; | friend class RtsKernelTaskBuilder; | ||||
| uintptr_t addresses_[2]; | |||||
| uintptr_t addresses_[kAddressNum]; | |||||
| size_t dst_max_; | size_t dst_max_; | ||||
| size_t count_; | size_t count_; | ||||
| rtMemcpyKind_t kind_; | rtMemcpyKind_t kind_; | ||||
| @@ -104,7 +104,7 @@ Status TbeTaskBuilder::DoRegisterBinary(const OpKernelBin &kernel_bin, void **bi | |||||
| binary.version = 0; | binary.version = 0; | ||||
| binary.data = kernel_bin.GetBinData(); | binary.data = kernel_bin.GetBinData(); | ||||
| binary.length = kernel_bin.GetBinDataSize(); | binary.length = kernel_bin.GetBinDataSize(); | ||||
| binary.magic = param.core_type == 0 ? RT_DEV_BINARY_MAGIC_ELF : RT_DEV_BINARY_MAGIC_ELF_AIVEC; | |||||
| GE_CHK_STATUS_RET_NOLOG(GetMagic(binary.magic)); | |||||
| Status ret = 0; | Status ret = 0; | ||||
| if (task_def_.type() == RT_MODEL_TASK_ALL_KERNEL) { | if (task_def_.type() == RT_MODEL_TASK_ALL_KERNEL) { | ||||
| ret = rtRegisterAllKernel(&binary, bin_handle); | ret = rtRegisterAllKernel(&binary, bin_handle); | ||||
| @@ -416,4 +416,27 @@ Status TbeTaskBuilder::InitTilingInfo(TbeOpTask &task) { | |||||
| task.EnableDynamicSupport(node_, tiling_buffer, static_cast<uint32_t>(max_size)); | task.EnableDynamicSupport(node_, tiling_buffer, static_cast<uint32_t>(max_size)); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| // Resolves the device-binary magic for this op from its TVM_ATTR_NAME_MAGIC string attribute. | |||||
| // @param magic [out] set to the matching RT_DEV_BINARY_MAGIC_* constant on success; not written on failure. | |||||
| // @return SUCCESS for a recognised value; PARAM_INVALID (after reporting the error) when the attribute | |||||
| //         could not be read (string stays empty) or holds any other value. | |||||
| Status TbeTaskBuilder::GetMagic(uint32_t &magic) const { | |||||
| std::string json_string; | |||||
| // When the attribute cannot be read json_string stays empty and is rejected by the final else branch. | |||||
| GE_IF_BOOL_EXEC(AttrUtils::GetStr(op_desc_, TVM_ATTR_NAME_MAGIC, json_string), | |||||
| GELOGD("Get attr:%s success, value:%s.", TVM_ATTR_NAME_MAGIC.c_str(), json_string.c_str())); | |||||
| if (json_string == "RT_DEV_BINARY_MAGIC_ELF") { | |||||
| magic = RT_DEV_BINARY_MAGIC_ELF; | |||||
| } else if (json_string == "RT_DEV_BINARY_MAGIC_ELF_AIVEC") { | |||||
| magic = RT_DEV_BINARY_MAGIC_ELF_AIVEC; | |||||
| } else if (json_string == "RT_DEV_BINARY_MAGIC_ELF_AICUBE") { | |||||
| magic = RT_DEV_BINARY_MAGIC_ELF_AICUBE; | |||||
| } else { | |||||
| REPORT_INNER_ERROR("E19999", "Attr:%s in op:%s(%s), value:%s check invalid", | |||||
| TVM_ATTR_NAME_MAGIC.c_str(), op_desc_->GetName().c_str(), | |||||
| op_desc_->GetType().c_str(), json_string.c_str()); | |||||
| GELOGE(PARAM_INVALID, "[Check][Param] Attr:%s in op:%s(%s), value:%s check invalid", | |||||
| TVM_ATTR_NAME_MAGIC.c_str(), op_desc_->GetName().c_str(), | |||||
| op_desc_->GetType().c_str(), json_string.c_str()); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -105,6 +105,7 @@ class TbeTaskBuilder { | |||||
| const SingleOpModelParam ¶m); | const SingleOpModelParam ¶m); | ||||
| Status DoRegisterBinary(const OpKernelBin &kernel_bin, void **bin_handle, const SingleOpModelParam ¶m) const; | Status DoRegisterBinary(const OpKernelBin &kernel_bin, void **bin_handle, const SingleOpModelParam ¶m) const; | ||||
| Status DoRegisterMeta(void *bin_handle); | Status DoRegisterMeta(void *bin_handle); | ||||
| Status GetMagic(uint32_t &magic) const; | |||||
| static Status DoRegisterFunction(void *bin_handle, const char *stub_name, const char *kernel_name); | static Status DoRegisterFunction(void *bin_handle, const char *stub_name, const char *kernel_name); | ||||
| @@ -1 +1 @@ | |||||
| Subproject commit 9e4a51a9602195b82e326b853f5adbfefc3972b6 | |||||
| Subproject commit 9c9907b76a457f456072af96b8cbcfb7943beccc | |||||
| @@ -1 +1 @@ | |||||
| Subproject commit 79536a196f89cf7a1f5852ff7304b9a7d7b12eff | |||||
| Subproject commit 15a27afefe45f2abdb78787d629163aab9437599 | |||||
| @@ -38,5 +38,20 @@ RUN wget https://github.com/ccup/lcov/archive/refs/tags/add_lcov.tar.gz -O add_l | |||||
| ENV PROJECT_HOME=/code/Turing/graphEngine | ENV PROJECT_HOME=/code/Turing/graphEngine | ||||
| # sshd setup so the build container can be reached over SSH (ports 22 and 7777 are exposed below). | |||||
| RUN mkdir /var/run/sshd | |||||
| # NOTE(review): hard-coded root password with root login enabled; confirm this image is never published or run on a shared host. | |||||
| RUN echo "root:root" | chpasswd | |||||
| RUN sed -i 's/\#PermitRootLogin prohibit-password/PermitRootLogin yes/' /etc/ssh/sshd_config | |||||
| RUN sed 's@session\s*required\s*pam_loginuid.so@session optional pam_loginuid.so@g' -i /etc/pam.d/sshd | |||||
| ENV NOTVISIBLE "in users profile" | |||||
| RUN echo "export VISIBLE=now" >> /etc/profile | |||||
| EXPOSE 22 7777 | |||||
| RUN useradd -ms /bin/bash debugger | |||||
| RUN echo "debugger:ge123" | chpasswd | |||||
| # Exec form must be a valid JSON array (comma-separated elements). "&" is a shell operator, not an sshd argument, | |||||
| # and "-D" already keeps sshd in the foreground. | |||||
| CMD ["/usr/sbin/sshd", "-D"] | |||||
| RUN echo "alias ge=/code/Turing/graphEngine/scripts/ge.sh">>~/.bashrc | RUN echo "alias ge=/code/Turing/graphEngine/scripts/ge.sh">>~/.bashrc | ||||
| @@ -21,7 +21,7 @@ MOUNT_PROJECT_HOME=$(cd $PROJECT_HOME || return; pwd) | |||||
| DOCKER_BUILD_ENV_NAME=${MOUNT_PROJECT_HOME#*/} | DOCKER_BUILD_ENV_NAME=${MOUNT_PROJECT_HOME#*/} | ||||
| DOCKER_BUILD_ENV_NAME=${DOCKER_BUILD_ENV_NAME//\//\_} | DOCKER_BUILD_ENV_NAME=${DOCKER_BUILD_ENV_NAME//\//\_} | ||||
| DOCKER_IMAGE_TAG=ge_build_env.1.0.6 | |||||
| DOCKER_IMAGE_TAG=ge_build_env.1.0.9 | |||||
| DOCKER_IAMGE_NAME=joycode2art/turing | DOCKER_IAMGE_NAME=joycode2art/turing | ||||
| DOCKER_FULL_IMAGE_NAME=${DOCKER_IAMGE_NAME}:${DOCKER_IMAGE_TAG} | DOCKER_FULL_IMAGE_NAME=${DOCKER_IAMGE_NAME}:${DOCKER_IMAGE_TAG} | ||||
| @@ -61,7 +61,7 @@ function enter_docker_env(){ | |||||
| if test -z "$(docker images |grep ${DOCKER_IAMGE_NAME} | grep ${DOCKER_IMAGE_TAG})"; then | if test -z "$(docker images |grep ${DOCKER_IAMGE_NAME} | grep ${DOCKER_IMAGE_TAG})"; then | ||||
| echo "please run 'ge env --pull' to download images first!" | echo "please run 'ge env --pull' to download images first!" | ||||
| elif test -z "$(docker ps -a |grep ${DOCKER_BUILD_ENV_NAME})"; then | elif test -z "$(docker ps -a |grep ${DOCKER_BUILD_ENV_NAME})"; then | ||||
| $docker_cmd run -it -v ${MOUNT_PROJECT_HOME}:/code/Turing/graphEngine --workdir ${docker_work_dir} --name ${DOCKER_BUILD_ENV_NAME} ${DOCKER_FULL_IMAGE_NAME} ${docker_bash_dir} | |||||
| $docker_cmd run -p 7002:22 -p 7003:7777 --privileged=true -it -v ${MOUNT_PROJECT_HOME}:/code/Turing/graphEngine --workdir ${docker_work_dir} --name ${DOCKER_BUILD_ENV_NAME} ${DOCKER_FULL_IMAGE_NAME} ${docker_bash_dir} | |||||
| elif test -z "$(docker ps |grep ${DOCKER_BUILD_ENV_NAME})"; then | elif test -z "$(docker ps |grep ${DOCKER_BUILD_ENV_NAME})"; then | ||||
| $docker_cmd start ${DOCKER_BUILD_ENV_NAME} | $docker_cmd start ${DOCKER_BUILD_ENV_NAME} | ||||
| $docker_cmd exec -w ${docker_work_dir} -it ${DOCKER_BUILD_ENV_NAME} ${docker_bash_dir} | $docker_cmd exec -w ${docker_work_dir} -it ${DOCKER_BUILD_ENV_NAME} ${docker_bash_dir} | ||||
| @@ -60,6 +60,7 @@ set(SRCS | |||||
| "${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" | "${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/utils/dumper/ge_graph_dumper.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc" | ||||
| @@ -17,16 +17,3 @@ include(cmake/graphengine.cmake) | |||||
| add_subdirectory(easy_graph) | add_subdirectory(easy_graph) | ||||
| add_subdirectory(ge_graph_dsl) | add_subdirectory(ge_graph_dsl) | ||||
| add_subdirectory(ge_running_env) | add_subdirectory(ge_running_env) | ||||
| file(GLOB_RECURSE UTILS_SRC CONFIGURE_DEPENDS | |||||
| "utils/*.cc" | |||||
| ) | |||||
| add_library(framework STATIC ${UTILS_SRC}) | |||||
| target_include_directories(framework | |||||
| PUBLIC utils/ | |||||
| ) | |||||
| set_target_properties(framework PROPERTIES CXX_STANDARD 11) | |||||
| target_link_libraries(framework PUBLIC ge_graph_dsl ge_with_env) | |||||
| @@ -26,16 +26,32 @@ EG_NS_BEGIN | |||||
| //////////////////////////////////////////////////////////////// | //////////////////////////////////////////////////////////////// | ||||
| namespace detail { | namespace detail { | ||||
| template<typename GRAPH_BUILDER> | |||||
| template <typename GRAPH_BUILDER> | |||||
| Graph BuildGraph(const char *name, GRAPH_BUILDER builderInDSL) { | Graph BuildGraph(const char *name, GRAPH_BUILDER builderInDSL) { | ||||
| GraphBuilder builder(name); | GraphBuilder builder(name); | ||||
| builderInDSL(builder); | builderInDSL(builder); | ||||
| return std::move(*builder); | return std::move(*builder); | ||||
| } | } | ||||
| // Carries the name of a graph being declared and builds the graph when piped into a builder callable. | |||||
| struct GraphDefiner { | |||||
| // Picks specifiedName when it is non-null, otherwise defaultName. | |||||
| GraphDefiner(const char *defaultName, const char *specifiedName = nullptr) | |||||
| : name(specifiedName != nullptr ? specifiedName : defaultName) {} | |||||
| // Creates a GraphBuilder with the stored name, lets userBuilder populate it, then returns the dereferenced builder. | |||||
| template <typename USER_BUILDER> | |||||
| auto operator|(USER_BUILDER &&userBuilder) { | |||||
| GraphBuilder builder{name}; | |||||
| std::forward<USER_BUILDER>(userBuilder)(builder); | |||||
| return *builder; | |||||
| } | |||||
| private: | |||||
| const char *name; | |||||
| }; | |||||
| } // namespace detail | } // namespace detail | ||||
| #define HAS_NAME(...) NOT_EMPTY_SELECT(__VA_ARGS__) | |||||
| #define DEF_GRAPH(G, ...) ::EG_NS::Graph G = ::EG_NS::detail::BuildGraph(HAS_NAME(__VA_ARGS__)(__VA_ARGS__, #G), [&](::EG_NS::GraphBuilder& BUILDER) | |||||
| #define DEF_GRAPH(G, ...) ::EG_NS::Graph G = ::EG_NS::detail::GraphDefiner(#G, ##__VA_ARGS__) | [&](auto &&BUILDER) | |||||
| #define DATA_CHAIN(...) ::EG_NS::ChainBuilder(BUILDER, ::EG_NS::EdgeType::DATA)->__VA_ARGS__ | #define DATA_CHAIN(...) ::EG_NS::ChainBuilder(BUILDER, ::EG_NS::EdgeType::DATA)->__VA_ARGS__ | ||||
| #define CTRL_CHAIN(...) ::EG_NS::ChainBuilder(BUILDER, ::EG_NS::EdgeType::CTRL)->__VA_ARGS__ | #define CTRL_CHAIN(...) ::EG_NS::ChainBuilder(BUILDER, ::EG_NS::EdgeType::CTRL)->__VA_ARGS__ | ||||
| #define CHAIN(...) DATA_CHAIN(__VA_ARGS__) | #define CHAIN(...) DATA_CHAIN(__VA_ARGS__) | ||||
| @@ -16,10 +16,15 @@ | |||||
| #include "easy_graph/layout/graph_layout.h" | #include "easy_graph/layout/graph_layout.h" | ||||
| #include "easy_graph/layout/layout_executor.h" | #include "easy_graph/layout/layout_executor.h" | ||||
| #include "easy_graph/layout/engines/graph_easy/graph_easy_executor.h" | |||||
| #include "easy_graph/graph/graph.h" | #include "easy_graph/graph/graph.h" | ||||
| EG_NS_BEGIN | EG_NS_BEGIN | ||||
| namespace { | |||||
| GraphEasyExecutor default_executor; | |||||
| } | |||||
| void GraphLayout::Config(LayoutExecutor &executor, const LayoutOption *opts) { | void GraphLayout::Config(LayoutExecutor &executor, const LayoutOption *opts) { | ||||
| this->executor_ = &executor; | this->executor_ = &executor; | ||||
| options_ = opts; | options_ = opts; | ||||
| @@ -27,8 +32,7 @@ void GraphLayout::Config(LayoutExecutor &executor, const LayoutOption *opts) { | |||||
| Status GraphLayout::Layout(const Graph &graph, const LayoutOption *opts) { | Status GraphLayout::Layout(const Graph &graph, const LayoutOption *opts) { | ||||
| const LayoutOption *options = opts ? opts : this->options_; | const LayoutOption *options = opts ? opts : this->options_; | ||||
| if (!executor_) | |||||
| return EG_UNIMPLEMENTED; | |||||
| if (!executor_) return static_cast<LayoutExecutor &>(default_executor).Layout(graph, options); | |||||
| return executor_->Layout(graph, options); | return executor_->Layout(graph, options); | ||||
| } | } | ||||
| @@ -0,0 +1,37 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef D52AA06185E34BBFB714FFBCDAB0D53A | |||||
| #define D52AA06185E34BBFB714FFBCDAB0D53A | |||||
| #include "ge_graph_dsl/ge.h" | |||||
| #include <exception> | |||||
| #include <string> | |||||
| GE_NS_BEGIN | |||||
| // Exception type for failed graph assertions; the constructor receives the source location and a description. | |||||
| struct AssertError : std::exception { | |||||
| AssertError(const char *file, int line, const std::string &info); | |||||
| // Kept public so a handler that catches AssertError itself (not only std::exception) can read the message. | |||||
| const char *what() const noexcept override; | |||||
| private: | |||||
| std::string info; | |||||
| }; | |||||
| GE_NS_END | |||||
| #endif | |||||
| @@ -0,0 +1,32 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_31309AA0A4E44C009C22AD9351BF3410 | |||||
| #define INC_31309AA0A4E44C009C22AD9351BF3410 | |||||
| #include "ge_graph_dsl/ge.h" | |||||
| #include "graph/compute_graph.h" | |||||
| GE_NS_BEGIN | |||||
| using GraphCheckFun = std::function<void(const ::GE_NS::ComputeGraphPtr &)>; | |||||
| // Static entry points used by the graph assertion helpers. | |||||
| struct CheckUtils { | |||||
| // Sets up the dumper that later CheckGraph calls rely on. | |||||
| static void init(); | |||||
| // Runs fun for the graph associated with phase_id; the bool result reports whether such a graph was available. | |||||
| static bool CheckGraph(const std::string &phase_id, const GraphCheckFun &fun); | |||||
| }; | |||||
| GE_NS_END | |||||
| #endif | |||||
| @@ -1,17 +1,32 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tensor_builder_utils.h" | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef C8B32320BD4943D588594B82FFBF2685 | |||||
| #define C8B32320BD4943D588594B82FFBF2685 | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include "ge_graph_dsl/ge.h" | |||||
| GE_NS_BEGIN | |||||
| // Scope guard for the dump filter: the constructor takes the filter strings, the destructor undoes the change. | |||||
| struct FilterScopeGuard { | |||||
| FilterScopeGuard(const std::vector<std::string> &); | |||||
| ~FilterScopeGuard(); | |||||
| // Non-copyable: destroying a copy would run the destructor's clean-up a second time while the | |||||
| // original guard is still in scope. | |||||
| FilterScopeGuard(const FilterScopeGuard &) = delete; | |||||
| FilterScopeGuard &operator=(const FilterScopeGuard &) = delete; | |||||
| }; | |||||
| GE_NS_END | |||||
| #endif | |||||
| @@ -0,0 +1,59 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef AD954C4ADF5B44F5B1CC8BCD72EE9ED6 | |||||
| #define AD954C4ADF5B44F5B1CC8BCD72EE9ED6 | |||||
| #include "ge_graph_dsl/ge.h" | |||||
| #include "ge_graph_dsl/assert/check_utils.h" | |||||
| #include "ge_graph_dsl/assert/assert_error.h" | |||||
| #include "ge_graph_dsl/assert/filter_scope_guard.h" | |||||
| GE_NS_BEGIN | |||||
| // Failure reporting for graph assertions: emit a fatal gtest failure when gtest's GTEST_MESSAGE_AT_ is | |||||
| // available, otherwise throw AssertError. The fallback branch must be #else: a bare #elif has no | |||||
| // condition and is ill-formed as soon as it is evaluated. | |||||
| #ifdef GTEST_MESSAGE_AT_ | |||||
| #define GRAPH_CHECK_MESSAGE(file, line, message) \ | |||||
| GTEST_MESSAGE_AT_(file, line, message, ::testing::TestPartResult::kFatalFailure) | |||||
| #else | |||||
| #define GRAPH_CHECK_MESSAGE(file, line, message) throw ::GE_NS::AssertError(file, line, message) | |||||
| #endif | |||||
| namespace detail { | |||||
| // Binds a source location to a phase id; piping a check function into it runs the check for that phase | |||||
| // and reports through GRAPH_CHECK_MESSAGE when CheckGraph returns false. | |||||
| struct GraphAssert { | |||||
| GraphAssert(const char *file, unsigned int line, const std::string &phase_id) | |||||
| : file_(file), line_(line), phase_id_(phase_id) {} | |||||
| void operator|(const ::GE_NS::GraphCheckFun &check_fun) { | |||||
| if (::GE_NS::CheckUtils::CheckGraph(phase_id_, check_fun)) { | |||||
| return; | |||||
| } | |||||
| // CheckGraph returned false: report at the location captured by the constructor. | |||||
| const std::string message = "expect dump graph in phase: [" + phase_id_ + "], while not find the dump graph! "; | |||||
| GRAPH_CHECK_MESSAGE(file_, line_, message.c_str()); | |||||
| } | |||||
| private: | |||||
| const char *file_; | |||||
| unsigned int line_; | |||||
| const std::string phase_id_; | |||||
| }; | |||||
| } // namespace detail | |||||
| #define DUMP_GRAPH_WHEN(...) ::GE_NS::FilterScopeGuard guard__COUNTER__({__VA_ARGS__}); | |||||
| #define CHECK_GRAPH(phase_id) \ | |||||
| ::GE_NS::detail::GraphAssert(__FILE__, __LINE__, #phase_id) | [&](const ::GE_NS::ComputeGraphPtr &graph) | |||||
| GE_NS_END | |||||
| #endif | |||||
| @@ -33,14 +33,12 @@ struct OpDescCfg { | |||||
| std::vector<int64_t> shape_; | std::vector<int64_t> shape_; | ||||
| }; | }; | ||||
| OpDescCfg(const OpType &type, int in_cnt = 0, int out_cnt = 0, Format format = FORMAT_NCHW, | |||||
| OpDescCfg(const OpType &type, int in_cnt = 1, int out_cnt = 1, Format format = FORMAT_NCHW, | |||||
| DataType data_type = DT_FLOAT, std::vector<int64_t> shape = {1, 1, 224, 224}) | DataType data_type = DT_FLOAT, std::vector<int64_t> shape = {1, 1, 224, 224}) | ||||
| : type_(type), in_cnt_(in_cnt), out_cnt_(out_cnt), default_tensor_(format, data_type, shape) {} | : type_(type), in_cnt_(in_cnt), out_cnt_(out_cnt), default_tensor_(format, data_type, shape) {} | ||||
| protected: | protected: | ||||
| OpType GetType() const { | |||||
| return type_; | |||||
| } | |||||
| OpType GetType() const { return type_; } | |||||
| OpType type_; | OpType type_; | ||||
| int in_cnt_; | int in_cnt_; | ||||
| int out_cnt_; | int out_cnt_; | ||||
| @@ -0,0 +1,26 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge_graph_dsl/assert/assert_error.h" | |||||
| GE_NS_BEGIN | |||||
| // Stores the message as "<file>:<line>\n<info>". | |||||
| AssertError::AssertError(const char *file, int line, const std::string &info) { | |||||
| std::string text(file); | |||||
| text += ":"; | |||||
| text += std::to_string(line); | |||||
| text += "\n"; | |||||
| text += info; | |||||
| this->info = text; | |||||
| } | |||||
| // Returns the message assembled by the constructor. | |||||
| const char *AssertError::what() const noexcept { | |||||
| return info.c_str(); | |||||
| } | |||||
| GE_NS_END | |||||
| @@ -0,0 +1,34 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge_graph_dsl/assert/check_utils.h" | |||||
| #include "graph/utils/dumper/ge_graph_dumper.h" | |||||
| #include "ge_graph_default_checker.h" | |||||
| #include "ge_graph_check_dumper.h" | |||||
| GE_NS_BEGIN | |||||
| // Wraps phase_id and fun in a GeGraphDefaultChecker and hands it to the registered dumper's CheckFor. | |||||
| // The registered dumper is cast to GeGraphCheckDumper by reference (dynamic_cast). | |||||
| bool CheckUtils::CheckGraph(const std::string &phase_id, const GraphCheckFun &fun) { | |||||
| auto &check_dumper = dynamic_cast<GeGraphCheckDumper &>(GraphDumperRegistry::GetDumper()); | |||||
| const GeGraphDefaultChecker checker(phase_id, fun); | |||||
| return check_dumper.CheckFor(checker); | |||||
| } | |||||
| // Registers a function-local static GeGraphCheckDumper with GraphDumperRegistry. | |||||
| void CheckUtils::init() { | |||||
| static GeGraphCheckDumper check_dumper; | |||||
| GraphDumperRegistry::Register(check_dumper); | |||||
| } | |||||
| GE_NS_END | |||||
| @@ -0,0 +1,31 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge_graph_dsl/assert/filter_scope_guard.h" | |||||
| #include "graph/utils/dumper/ge_graph_dumper.h" | |||||
| #include "ge_dump_filter.h" | |||||
| GE_NS_BEGIN | |||||
| namespace { | |||||
| // Views the currently registered dumper through its GeDumpFilter interface (dynamic_cast to a reference). | |||||
| GeDumpFilter &GetDumpFilter() { | |||||
| auto &registered = GraphDumperRegistry::GetDumper(); | |||||
| return dynamic_cast<GeDumpFilter &>(registered); | |||||
| } | |||||
| } // namespace | |||||
| // Installs the given filter on the registered dumper. | |||||
| FilterScopeGuard::FilterScopeGuard(const std::vector<std::string> &filter) { | |||||
| GetDumpFilter().Update(filter); | |||||
| } | |||||
| // Resets the registered dumper's filter. | |||||
| FilterScopeGuard::~FilterScopeGuard() { | |||||
| GetDumpFilter().Reset(); | |||||
| } | |||||
| GE_NS_END | |||||
| @@ -0,0 +1,33 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_4C6224E8F7474EF89B18CCB0E4B19FD6 | |||||
| #define INC_4C6224E8F7474EF89B18CCB0E4B19FD6 | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include "ge_graph_dsl/ge.h" | |||||
| #include "easy_graph/infra/keywords.h" | |||||
| GE_NS_BEGIN | |||||
| INTERFACE(GeDumpFilter) { /* Filter-control interface for graph dumping. */ | |||||
| ABSTRACT(void Update(const std::vector<std::string> &)); /* Supply a new list of filter strings. */ | |||||
| ABSTRACT(void Reset()); /* Undo Update; the exact default is up to the implementation. */ | |||||
| }; | |||||
| GE_NS_END | |||||
| #endif | |||||
| @@ -0,0 +1,79 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge_graph_check_dumper.h" | |||||
| #include "graph/model.h" | |||||
| #include "graph/buffer.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "ge_graph_default_checker.h" | |||||
| GE_NS_BEGIN | |||||
| // Starts from the state that Reset() establishes. | |||||
| GeGraphCheckDumper::GeGraphCheckDumper() { | |||||
| Reset(); | |||||
| } | |||||
| // True when suffix is one of the currently configured suffixes_ (exact string match). | |||||
| bool GeGraphCheckDumper::IsNeedDump(const std::string &suffix) const { | |||||
| return std::find(suffixes_.begin(), suffixes_.end(), suffix) != suffixes_.end(); | |||||
| } | |||||
| // Stores a dump of graph under suffix, replacing any earlier dump for the same suffix. | |||||
| // Suffixes that IsNeedDump rejects are ignored. | |||||
| void GeGraphCheckDumper::Dump(const ge::ComputeGraphPtr &graph, const std::string &suffix) { | |||||
| if (IsNeedDump(suffix)) { | |||||
| // operator[] yields the existing buffer for this suffix or default-constructs a fresh one. | |||||
| DumpGraph(graph, buffers_[suffix]); | |||||
| } | |||||
| } | |||||
| // Runs checker against the dump stored for checker.PhaseId(). | |||||
| // @return false when no dump exists for that phase (the checker is not run), true otherwise. | |||||
| bool GeGraphCheckDumper::CheckFor(const GeGraphChecker &checker) { | |||||
| auto found = buffers_.find(checker.PhaseId()); | |||||
| const bool has_dump = (found != buffers_.end()); | |||||
| if (has_dump) { | |||||
| DoCheck(checker, found->second); | |||||
| } | |||||
| return has_dump; | |||||
| } | |||||
| // Loads a Model from buffer and passes the compute graph of the loaded model to checker.Check. | |||||
| // NOTE(review): whatever Model::Load returns is ignored here; confirm a failed load cannot reach the checker. | |||||
| void GeGraphCheckDumper::DoCheck(const GeGraphChecker &checker, ::GE_NS::Buffer &buffer) { | |||||
| Model loaded_model("", ""); | |||||
| Model::Load(buffer.GetData(), buffer.GetSize(), loaded_model); | |||||
| auto loaded_graph = loaded_model.GetGraph(); | |||||
| auto compute_graph = GraphUtils::GetComputeGraph(loaded_graph); | |||||
| checker.Check(compute_graph); | |||||
| } | |||||
| // Clears buffer, then saves a Model wrapping graph into it. | |||||
| void GeGraphCheckDumper::DumpGraph(const ge::ComputeGraphPtr &graph, ::GE_NS::Buffer &buffer) { | |||||
| buffer.clear(); | |||||
| Model dump_model("", ""); | |||||
| dump_model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph)); | |||||
| dump_model.Save(buffer, true); | |||||
| } | |||||
| // Replaces the accepted suffix list and drops every stored dump. | |||||
| void GeGraphCheckDumper::Update(const std::vector<std::string> &new_suffixes_) { | |||||
| buffers_.clear(); | |||||
| suffixes_ = new_suffixes_; | |||||
| } | |||||
| // Restores the default suffix list (only "PreRunAfterBuild") and drops every stored dump. | |||||
| void GeGraphCheckDumper::Reset() { | |||||
| static std::vector<std::string> default_suffixes_{"PreRunAfterBuild"}; | |||||
| buffers_.clear(); | |||||
| suffixes_ = default_suffixes_; | |||||
| } | |||||
| GE_NS_END | |||||
| @@ -0,0 +1,49 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_8EFED0015C27464897BF64531355C810 | |||||
| #define INC_8EFED0015C27464897BF64531355C810 | |||||
| #include "ge_graph_dsl/ge.h" | |||||
| #include "graph/utils/dumper/ge_graph_dumper.h" | |||||
| #include "ge_dump_filter.h" | |||||
| #include <string> | |||||
| GE_NS_BEGIN | |||||
| struct GeGraphChecker; | |||||
| // Dumper that also implements GeDumpFilter: keeps one Buffer per suffix and a list of suffixes. | |||||
| struct GeGraphCheckDumper : GeGraphDumper, GeDumpFilter { | |||||
| GeGraphCheckDumper(); | |||||
| virtual void Dump(const ge::ComputeGraphPtr &graph, const std::string &suffix); | |||||
| bool CheckFor(const GeGraphChecker &checker); | |||||
| private: | |||||
| // GeDumpFilter implementation. | |||||
| void Update(const std::vector<std::string> &) override; | |||||
| void Reset() override; | |||||
| // Internal helpers. | |||||
| bool IsNeedDump(const std::string &suffix) const; | |||||
| void DoCheck(const GeGraphChecker &checker, ::GE_NS::Buffer &buffer); | |||||
| void DumpGraph(const ge::ComputeGraphPtr &graph, ::GE_NS::Buffer &buffer); | |||||
| // State: one buffer per suffix, plus the list of suffixes. | |||||
| std::map<std::string, ::GE_NS::Buffer> buffers_; | |||||
| std::vector<std::string> suffixes_; | |||||
| }; | |||||
| GE_NS_END | |||||
| #endif | |||||
| @@ -0,0 +1,32 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_5960A8F437324904BEE0690271258762 | |||||
| #define INC_5960A8F437324904BEE0690271258762 | |||||
| #include "ge_graph_dsl/ge.h" | |||||
| #include "easy_graph/infra/keywords.h" | |||||
| #include "graph/compute_graph.h" | |||||
| GE_NS_BEGIN | |||||
| INTERFACE(GeGraphChecker) { /* A check bound to one phase id. */ | |||||
| ABSTRACT(const std::string &PhaseId() const); /* Phase id this checker applies to. */ | |||||
| ABSTRACT(void Check(const ge::ComputeGraphPtr &graph) const); /* Performs the check on the given graph. */ | |||||
| }; | |||||
| GE_NS_END | |||||
| #endif | |||||
| @@ -0,0 +1,28 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge_graph_default_checker.h" | |||||
| GE_NS_BEGIN | |||||
| // Keeps copies of the phase id and of the check callable. | |||||
| GeGraphDefaultChecker::GeGraphDefaultChecker(const std::string &phase_id, const GraphCheckFun &check_fun) | |||||
| : phase_id_(phase_id), check_fun_(check_fun) {} | |||||
| // Phase id given at construction. | |||||
| const std::string &GeGraphDefaultChecker::PhaseId() const { | |||||
| return phase_id_; | |||||
| } | |||||
| // Forwards graph to the stored callable. | |||||
| void GeGraphDefaultChecker::Check(const ge::ComputeGraphPtr &graph) const { | |||||
| check_fun_(graph); | |||||
| } | |||||
| GE_NS_END | |||||
| @@ -0,0 +1,41 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef BCF4D96BE9FC48938DE7B7E93B551C54 | |||||
| #define BCF4D96BE9FC48938DE7B7E93B551C54 | |||||
| #include "ge_graph_dsl/ge.h" | |||||
| #include "ge_graph_checker.h" | |||||
| #include "graph/compute_graph.h" | |||||
| GE_NS_BEGIN | |||||
| using GraphCheckFun = std::function<void(const ::GE_NS::ComputeGraphPtr &)>; | |||||
| // GeGraphChecker implementation that pairs a phase id with a user-supplied check callable. | |||||
| struct GeGraphDefaultChecker : GeGraphChecker { | |||||
| GeGraphDefaultChecker(const std::string &, const GraphCheckFun &); | |||||
| private: | |||||
| // GeGraphChecker overrides. | |||||
| const std::string &PhaseId() const override; | |||||
| void Check(const ge::ComputeGraphPtr &graph) const override; | |||||
| // Values captured at construction. | |||||
| const std::string phase_id_; | |||||
| const GraphCheckFun check_fun_; | |||||
| }; | |||||
| GE_NS_END | |||||
| #endif | |||||
| @@ -23,15 +23,22 @@ GE_NS_BEGIN | |||||
| namespace { | namespace { | ||||
| #define OP_CFG(optype, ...) \ | |||||
| { \ | |||||
| optype, OpDescCfg { \ | |||||
| optype, __VA_ARGS__ \ | |||||
| } \ | |||||
| #define OP_CFG(optype, ...) \ | |||||
| { \ | |||||
| optype, OpDescCfg { optype, __VA_ARGS__ } \ | |||||
| } | } | ||||
| static std::map<OpType, OpDescCfg> cfg_repo{OP_CFG(DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | static std::map<OpType, OpDescCfg> cfg_repo{OP_CFG(DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | ||||
| OP_CFG(ADD, 2, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | OP_CFG(ADD, 2, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | ||||
| OP_CFG(ENTER, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||||
| OP_CFG(MERGE, 2, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||||
| OP_CFG(CONSTANT, 0, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||||
| OP_CFG(LESS, 2, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||||
| OP_CFG(LOOPCOND, 1, 1, FORMAT_NCHW, DT_BOOL, {1, 1, 224, 224}), | |||||
| OP_CFG(SWITCH, 2, 2, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||||
| OP_CFG(EXIT, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||||
| OP_CFG(NEXTITERATION, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||||
| OP_CFG(NETOUTPUT, 2, 2, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||||
| OP_CFG(VARIABLE, 1, 1)}; | OP_CFG(VARIABLE, 1, 1)}; | ||||
| } // namespace | } // namespace | ||||
| @@ -19,6 +19,4 @@ | |||||
| USING_GE_NS | USING_GE_NS | ||||
| OpDescPtr OpDescPtrBox::Build(const ::EG_NS::NodeId &id) const { | |||||
| return op_; | |||||
| } | |||||
| OpDescPtr OpDescPtrBox::Build(const ::EG_NS::NodeId &id) const { return op_; } | |||||
| @@ -36,17 +36,11 @@ GE_NS_BEGIN | |||||
| GeGraphVisitor::GeGraphVisitor() : build_graph_(std::make_shared<ComputeGraph>("")) {} | GeGraphVisitor::GeGraphVisitor() : build_graph_(std::make_shared<ComputeGraph>("")) {} | ||||
| void GeGraphVisitor::reset(const ComputeGraphPtr &graph) { | |||||
| build_graph_ = graph; | |||||
| } | |||||
| void GeGraphVisitor::reset(const ComputeGraphPtr &graph) { build_graph_ = graph; } | |||||
| Graph GeGraphVisitor::BuildGeGraph() const { | |||||
| return GraphUtils::CreateGraphFromComputeGraph(build_graph_); | |||||
| } | |||||
| Graph GeGraphVisitor::BuildGeGraph() const { return GraphUtils::CreateGraphFromComputeGraph(build_graph_); } | |||||
| ComputeGraphPtr GeGraphVisitor::BuildComputeGraph() const { | |||||
| return build_graph_; | |||||
| } | |||||
| ComputeGraphPtr GeGraphVisitor::BuildComputeGraph() const { return build_graph_; } | |||||
| Status GeGraphVisitor::Visit(const ::EG_NS::Graph &graph) { | Status GeGraphVisitor::Visit(const ::EG_NS::Graph &graph) { | ||||
| build_graph_->SetName(graph.GetName()); | build_graph_->SetName(graph.GetName()); | ||||
| @@ -26,7 +26,7 @@ target_compile_options(ge_graph_dsl_test PRIVATE | |||||
| ) | ) | ||||
| set_target_properties(ge_graph_dsl_test PROPERTIES CXX_STANDARD 17) | set_target_properties(ge_graph_dsl_test PROPERTIES CXX_STANDARD 17) | ||||
| target_link_libraries(ge_graph_dsl_test PUBLIC gtest gtest_main ge_graph_dsl) | |||||
| target_link_libraries(ge_graph_dsl_test PUBLIC gtest ge_graph_dsl) | |||||
| include(CTest) | include(CTest) | ||||
| enable_testing() | enable_testing() | ||||
| @@ -0,0 +1,129 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "gtest/gtest.h" | |||||
| #include "easy_graph/layout/graph_layout.h" | |||||
| #include "easy_graph/layout/engines/graph_easy/graph_easy_option.h" | |||||
| #include "easy_graph/layout/engines/graph_easy/graph_easy_executor.h" | |||||
| #include "ge_graph_dsl/graph_dsl.h" | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "graph/utils/dumper/ge_graph_dumper.h" | |||||
| #include "framework/common/types.h" | |||||
| #include "ge_graph_dsl/assert/graph_assert.h" | |||||
| #include "graph/model.h" | |||||
| #include "graph/buffer.h" | |||||
| USING_GE_NS | |||||
| class CheckGraphTest : public testing::Test { | |||||
| private: | |||||
| EG_NS::GraphEasyExecutor executor; | |||||
| protected: | |||||
| void SetUp() { EG_NS::GraphLayout::GetInstance().Config(executor, nullptr); } | |||||
| void TearDown() {} | |||||
| }; | |||||
| TEST_F(CheckGraphTest, test_ge_graph_dump_is_work) { | |||||
| DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||||
| DUMP_GRAPH_WHEN("after_build"); | |||||
| GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "after_build"); | |||||
| CHECK_GRAPH(after_build) { | |||||
| ASSERT_EQ(graph->GetName(), "g1"); | |||||
| ASSERT_EQ(graph->GetAllNodesSize(), 2); | |||||
| }; | |||||
| } | |||||
| TEST_F(CheckGraphTest, test_ge_graph_dump_two_phase) { | |||||
| DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||||
| DEF_GRAPH(g2) { | |||||
| CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||||
| CTRL_CHAIN(NODE("data2", DATA)->NODE("add", ADD)); | |||||
| }; | |||||
| DUMP_GRAPH_WHEN("before_build", "after_build"); | |||||
| GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "before_build"); | |||||
| GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g2), "after_build"); | |||||
| CHECK_GRAPH(before_build) { | |||||
| ASSERT_EQ(graph->GetName(), "g1"); | |||||
| ASSERT_EQ(graph->GetAllNodesSize(), 2); | |||||
| }; | |||||
| CHECK_GRAPH(after_build) { | |||||
| ASSERT_EQ(graph->GetName(), "g2"); | |||||
| ASSERT_EQ(graph->GetAllNodesSize(), 3); | |||||
| }; | |||||
| } | |||||
| TEST_F(CheckGraphTest, test_ge_graph_dump_one_phase_two_times) { | |||||
| DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||||
| DEF_GRAPH(g2) { | |||||
| CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||||
| CTRL_CHAIN(NODE("data2", DATA)->NODE("add", ADD)); | |||||
| }; | |||||
| DUMP_GRAPH_WHEN("before_build") | |||||
| GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "before_build"); | |||||
| GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g2), "before_build"); | |||||
| CHECK_GRAPH(before_build) { | |||||
| ASSERT_EQ(graph->GetName(), "g2"); | |||||
| ASSERT_EQ(graph->GetAllNodesSize(), 3); | |||||
| }; | |||||
| } | |||||
| TEST_F(CheckGraphTest, test_check_phases_is_work) { | |||||
| DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||||
| DUMP_GRAPH_WHEN("before_build"); | |||||
| GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "after_build"); | |||||
| auto ret = ::GE_NS::CheckUtils::CheckGraph("after_build", [&](const ::GE_NS::ComputeGraphPtr &graph) {}); | |||||
| ASSERT_FALSE(ret); | |||||
| } | |||||
| TEST_F(CheckGraphTest, test_check_one_phase_dump_another_not_dump) { | |||||
| DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||||
| DUMP_GRAPH_WHEN("before_build"); | |||||
| GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "before_build"); | |||||
| GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "after_build"); | |||||
| CHECK_GRAPH(before_build) { | |||||
| ASSERT_EQ(graph->GetName(), "g1"); | |||||
| ASSERT_EQ(graph->GetAllNodesSize(), 2); | |||||
| }; | |||||
| } | |||||
| TEST_F(CheckGraphTest, test_model_serialize_and_unserialize_success) { | |||||
| DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||||
| auto ge_graph = ToGeGraph(g1); | |||||
| ge::Model model("", ""); | |||||
| model.SetGraph(ge_graph); | |||||
| Buffer buffer; | |||||
| model.Save(buffer, true); | |||||
| ge::Model loadModel("", ""); | |||||
| Model::Load(buffer.GetData(), buffer.GetSize(), loadModel); | |||||
| auto load_graph = loadModel.GetGraph(); | |||||
| ASSERT_EQ(load_graph.GetName(), "g1"); | |||||
| ASSERT_EQ(load_graph.GetAllNodes().size(), 2); | |||||
| } | |||||
| @@ -37,17 +37,13 @@ class GraphDslTest : public testing::Test { | |||||
| EG_NS::GraphEasyExecutor executor; | EG_NS::GraphEasyExecutor executor; | ||||
| protected: | protected: | ||||
| void SetUp() { | |||||
| EG_NS::GraphLayout::GetInstance().Config(executor, nullptr); | |||||
| } | |||||
| void SetUp() { EG_NS::GraphLayout::GetInstance().Config(executor, nullptr); } | |||||
| void TearDown() {} | void TearDown() {} | ||||
| }; | }; | ||||
| TEST_F(GraphDslTest, test_build_graph_from_optype_with_name) { | TEST_F(GraphDslTest, test_build_graph_from_optype_with_name) { | ||||
| DEF_GRAPH(g1) { | |||||
| CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||||
| }); | |||||
| DEF_GRAPH(g1) { CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||||
| auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
| auto computeGraph = ToComputeGraph(g1); | auto computeGraph = ToComputeGraph(g1); | ||||
| @@ -57,9 +53,7 @@ TEST_F(GraphDslTest, test_build_graph_from_optype_with_name) { | |||||
| } | } | ||||
| TEST_F(GraphDslTest, test_build_graph_with_name) { | TEST_F(GraphDslTest, test_build_graph_with_name) { | ||||
| DEF_GRAPH(g1, "sample_graph") { | |||||
| CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||||
| }); | |||||
| DEF_GRAPH(g1, "sample_graph") { CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||||
| auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
| @@ -72,7 +66,7 @@ TEST_F(GraphDslTest, test_build_from_from_op_desc_ptr) { | |||||
| auto data = std::make_shared<OpDesc>("data1", DATA); | auto data = std::make_shared<OpDesc>("data1", DATA); | ||||
| auto add = std::make_shared<OpDesc>("Add", ADD); | auto add = std::make_shared<OpDesc>("Add", ADD); | ||||
| CHAIN(NODE(data)->NODE(add)); | CHAIN(NODE(data)->NODE(add)); | ||||
| }); | |||||
| }; | |||||
| auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
| @@ -84,7 +78,7 @@ TEST_F(GraphDslTest, test_build_from_op_desc_cfg) { | |||||
| auto datCfg = OP_CFG(DATA).InCnt(1).OutCnt(1); | auto datCfg = OP_CFG(DATA).InCnt(1).OutCnt(1); | ||||
| auto addCfg = OP_CFG(DATA).InCnt(1).OutCnt(1); | auto addCfg = OP_CFG(DATA).InCnt(1).OutCnt(1); | ||||
| CHAIN(NODE("data1", datCfg)->NODE("add", addCfg)); | CHAIN(NODE("data1", datCfg)->NODE("add", addCfg)); | ||||
| }); | |||||
| }; | |||||
| auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
| @@ -92,9 +86,7 @@ TEST_F(GraphDslTest, test_build_from_op_desc_cfg) { | |||||
| } | } | ||||
| TEST_F(GraphDslTest, test_build_from_op_desc_cfg_inline) { | TEST_F(GraphDslTest, test_build_from_op_desc_cfg_inline) { | ||||
| DEF_GRAPH(g1) { | |||||
| CHAIN(NODE("data1", OP_CFG(DATA).InCnt(1).OutCnt(1))->NODE("add", OP_CFG(ADD).InCnt(2).OutCnt(1))); | |||||
| }); | |||||
| DEF_GRAPH(g1) { CHAIN(NODE("data1", OP_CFG(DATA).InCnt(1).OutCnt(1))->NODE("add", OP_CFG(ADD).InCnt(2).OutCnt(1))); }; | |||||
| auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
| @@ -102,9 +94,7 @@ TEST_F(GraphDslTest, test_build_from_op_desc_cfg_inline) { | |||||
| } | } | ||||
| TEST_F(GraphDslTest, test_build_from_control_chain) { | TEST_F(GraphDslTest, test_build_from_control_chain) { | ||||
| DEF_GRAPH(g1) { | |||||
| CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||||
| }); | |||||
| DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||||
| auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
| @@ -112,9 +102,7 @@ TEST_F(GraphDslTest, test_build_from_control_chain) { | |||||
| } | } | ||||
| TEST_F(GraphDslTest, test_build_from_data_chain) { | TEST_F(GraphDslTest, test_build_from_data_chain) { | ||||
| DEF_GRAPH(g1) { | |||||
| DATA_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||||
| }); | |||||
| DEF_GRAPH(g1) { DATA_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||||
| auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
| @@ -125,7 +113,7 @@ TEST_F(GraphDslTest, test_build_from_data_chain_with_edge) { | |||||
| DEF_GRAPH(g1) { | DEF_GRAPH(g1) { | ||||
| CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | ||||
| CHAIN(NODE("data1", DATA)->EDGE(2, 2)->NODE("add")); | CHAIN(NODE("data1", DATA)->EDGE(2, 2)->NODE("add")); | ||||
| }); | |||||
| }; | |||||
| auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
| @@ -136,7 +124,7 @@ TEST_F(GraphDslTest, test_build_graph_reused_before_node) { | |||||
| DEF_GRAPH(g1) { | DEF_GRAPH(g1) { | ||||
| CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | ||||
| CHAIN(NODE("data1")->EDGE(2, 2)->NODE("add")); | CHAIN(NODE("data1")->EDGE(2, 2)->NODE("add")); | ||||
| }); | |||||
| }; | |||||
| auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
| @@ -147,7 +135,7 @@ TEST_F(GraphDslTest, test_build_graph_with_constant_folding) { | |||||
| DEF_GRAPH(g1) { | DEF_GRAPH(g1) { | ||||
| CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | ||||
| CHAIN(NODE("data2", DATA)->NODE("add")); | CHAIN(NODE("data2", DATA)->NODE("add")); | ||||
| }); | |||||
| }; | |||||
| auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
| @@ -168,7 +156,7 @@ TEST_F(GraphDslTest, test_build_complex_normal_graph_build_suggested) { | |||||
| ->NODE("Add4") | ->NODE("Add4") | ||||
| ->NODE("Add5") | ->NODE("Add5") | ||||
| ->NODE("net_output", NETOUTPUT)); | ->NODE("net_output", NETOUTPUT)); | ||||
| }); | |||||
| }; | |||||
| auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
| @@ -187,7 +175,7 @@ TEST_F(GraphDslTest, test_build_complex_mult_normal_graph_build) { | |||||
| CHAIN(NODE("add2")->NODE("net_output")); | CHAIN(NODE("add2")->NODE("net_output")); | ||||
| CHAIN(NODE("add3")->NODE("net_output")); | CHAIN(NODE("add3")->NODE("net_output")); | ||||
| CTRL_CHAIN(NODE("add1")->NODE("add2")->NODE("add3")); | CTRL_CHAIN(NODE("add1")->NODE("add2")->NODE("add3")); | ||||
| }); | |||||
| }; | |||||
| auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
| @@ -198,17 +186,17 @@ TEST_F(GraphDslTest, test_build_graph_with_sub_graph) { | |||||
| DEF_GRAPH(sub_1) { | DEF_GRAPH(sub_1) { | ||||
| CHAIN(NODE("data_i", DATA)->NODE("less", LESS)->NODE("netoutput", NETOUTPUT)); | CHAIN(NODE("data_i", DATA)->NODE("less", LESS)->NODE("netoutput", NETOUTPUT)); | ||||
| CHAIN(NODE("const_5", CONSTANTOP)->NODE("less")); | CHAIN(NODE("const_5", CONSTANTOP)->NODE("less")); | ||||
| }); | |||||
| }; | |||||
| DEF_GRAPH(sub_2) { | DEF_GRAPH(sub_2) { | ||||
| CHAIN(NODE("data_a", DATA)->NODE("mul", MUL)->NODE("netoutput", NETOUTPUT)); | CHAIN(NODE("data_a", DATA)->NODE("mul", MUL)->NODE("netoutput", NETOUTPUT)); | ||||
| CHAIN(NODE("const_2", CONSTANTOP)->NODE("mul")); | CHAIN(NODE("const_2", CONSTANTOP)->NODE("mul")); | ||||
| }); | |||||
| }; | |||||
| DEF_GRAPH(g1) { | DEF_GRAPH(g1) { | ||||
| CHAIN(NODE("data_a", DATA)->NODE("while", WHILE, sub_1, sub_2)->NODE("netoutput", NETOUTPUT)); | CHAIN(NODE("data_a", DATA)->NODE("while", WHILE, sub_1, sub_2)->NODE("netoutput", NETOUTPUT)); | ||||
| CHAIN(NODE("data_i", DATA)->NODE("while")); | CHAIN(NODE("data_i", DATA)->NODE("while")); | ||||
| }); | |||||
| }; | |||||
| sub_1.Layout(); | sub_1.Layout(); | ||||
| sub_2.Layout(); | sub_2.Layout(); | ||||
| @@ -30,5 +30,11 @@ REGISTER_OPTYPE_DEFINE(MUL, "Mul"); | |||||
| REGISTER_OPTYPE_DEFINE(NETOUTPUT, "NetOutput"); | REGISTER_OPTYPE_DEFINE(NETOUTPUT, "NetOutput"); | ||||
| REGISTER_OPTYPE_DEFINE(ADD, "Add"); | REGISTER_OPTYPE_DEFINE(ADD, "Add"); | ||||
| REGISTER_OPTYPE_DEFINE(WHILE, "While"); | REGISTER_OPTYPE_DEFINE(WHILE, "While"); | ||||
| REGISTER_OPTYPE_DEFINE(ENTER, "Enter"); | |||||
| REGISTER_OPTYPE_DEFINE(MERGE, "Merge"); | |||||
| REGISTER_OPTYPE_DEFINE(LOOPCOND, "Loopcond"); | |||||
| REGISTER_OPTYPE_DEFINE(SWITCH, "Switch"); | |||||
| REGISTER_OPTYPE_DEFINE(EXIT, "Exit"); | |||||
| REGISTER_OPTYPE_DEFINE(NEXTITERATION, "Nextiteration"); | |||||
| GE_NS_END | GE_NS_END | ||||
| @@ -1,22 +1,25 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H | |||||
| #define GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H | |||||
| class tensor_builder_utils {}; | |||||
| #endif // GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <gtest/gtest.h> | |||||
| #include "ge_graph_dsl/assert/check_utils.h" | |||||
| int main(int argc, char **argv) { | |||||
| ::GE_NS::CheckUtils::init(); | |||||
| testing::InitGoogleTest(&argc, argv); | |||||
| int ret = RUN_ALL_TESTS(); | |||||
| return ret; | |||||
| } | |||||
| @@ -1,48 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph_builder_utils.h" | |||||
| #include "inc/external/graph/operator.h" | |||||
| #include "inc/external/graph/operator_factory.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| namespace ge { | |||||
| namespace st { | |||||
| NodePtr ComputeGraphBuilder::AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, | |||||
| Format format, DataType data_type, std::vector<int64_t> shape) { | |||||
| auto tensor_desc = std::make_shared<GeTensorDesc>(); | |||||
| tensor_desc->SetShape(GeShape(std::move(shape))); | |||||
| tensor_desc->SetFormat(format); | |||||
| tensor_desc->SetDataType(data_type); | |||||
| auto op_desc = std::make_shared<OpDesc>(name, type); | |||||
| for (int i = 0; i < in_cnt; ++i) { | |||||
| op_desc->AddInputDesc(tensor_desc->Clone()); | |||||
| } | |||||
| for (int i = 0; i < out_cnt; ++i) { | |||||
| op_desc->AddOutputDesc(tensor_desc->Clone()); | |||||
| } | |||||
| return graph_->AddNode(op_desc); | |||||
| } | |||||
| void ComputeGraphBuilder::AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx) { | |||||
| GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_idx), dst_node->GetInDataAnchor(dst_idx)); | |||||
| } | |||||
| void ComputeGraphBuilder::AddControlEdge(NodePtr &src_node, NodePtr &dst_node) { | |||||
| GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor()); | |||||
| } | |||||
| } // namespace st | |||||
| } // namespace ge | |||||
| @@ -1,55 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H | |||||
| #define GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "graph/compute_graph.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "graph/graph.h" | |||||
| #include "graph/node.h" | |||||
| namespace ge { | |||||
| namespace st { | |||||
| class ComputeGraphBuilder { | |||||
| public: | |||||
| explicit ComputeGraphBuilder(const std::string &name) { | |||||
| graph_ = std::make_shared<ComputeGraph>(name); | |||||
| } | |||||
| NodePtr AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, | |||||
| Format format = FORMAT_NCHW, DataType data_type = DT_FLOAT, | |||||
| std::vector<int64_t> shape = {1, 1, 224, 224}); | |||||
| void AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx); | |||||
| void AddControlEdge(NodePtr &src_node, NodePtr &dst_node); | |||||
| ComputeGraphPtr GetComputeGraph() { | |||||
| graph_->TopologicalSorting(); | |||||
| return graph_; | |||||
| } | |||||
| Graph GetGraph() { | |||||
| graph_->TopologicalSorting(); | |||||
| return GraphUtils::CreateGraphFromComputeGraph(graph_); | |||||
| } | |||||
| private: | |||||
| ComputeGraphPtr graph_; | |||||
| }; | |||||
| } // namespace st | |||||
| } // namespace ge | |||||
| #endif // GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H | |||||
| @@ -8,7 +8,7 @@ target_include_directories(graph_engine_test | |||||
| set_target_properties(graph_engine_test PROPERTIES CXX_STANDARD 17) | set_target_properties(graph_engine_test PROPERTIES CXX_STANDARD 17) | ||||
| target_link_libraries(graph_engine_test PRIVATE gtest framework) | |||||
| target_link_libraries(graph_engine_test PRIVATE gtest ge_graph_dsl ge_with_env) | |||||
| include(CTest) | include(CTest) | ||||
| enable_testing() | enable_testing() | ||||
| @@ -15,23 +15,12 @@ | |||||
| */ | */ | ||||
| #include <gtest/gtest.h> | #include <gtest/gtest.h> | ||||
| #include <map> | |||||
| #include "external/ge/ge_api.h" | #include "external/ge/ge_api.h" | ||||
| #include "ge_running_env/fake_engine.h" | |||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "framework/common/types.h" | #include "framework/common/types.h" | ||||
| #include "builder/graph_builder_utils.h" | |||||
| #include "ge_running_env/ge_running_env_faker.h" | #include "ge_running_env/ge_running_env_faker.h" | ||||
| #include "graph/operator_reg.h" | |||||
| #include "graph/operator.h" | |||||
| #define protected public | |||||
| #define private public | |||||
| #include "graph/utils/op_desc_utils.h" | |||||
| #include "ge_graph_dsl/graph_dsl.h" | #include "ge_graph_dsl/graph_dsl.h" | ||||
| #undef protected | |||||
| #undef private | |||||
| #include "ge_graph_dsl/assert/graph_assert.h" | |||||
| using namespace std; | using namespace std; | ||||
| using namespace ge; | using namespace ge; | ||||
| @@ -57,76 +46,58 @@ namespace { | |||||
| * | * | ||||
| **/ | **/ | ||||
| Graph BuildV1ControlFlowGraph() { | Graph BuildV1ControlFlowGraph() { | ||||
| // build graph | |||||
| st::ComputeGraphBuilder graphBuilder("g1"); | |||||
| auto data_i = graphBuilder.AddNode("data_i", DATA, 1, 1); | |||||
| auto enter_i = graphBuilder.AddNode("enter_i", ENTER, 1, 1); | |||||
| ge::AttrUtils::SetStr(enter_i->GetOpDesc(), ENTER_ATTR_FRAME_NAME, "1"); | |||||
| auto merge_i = graphBuilder.AddNode("merge_i", MERGE, 2, 1); | |||||
| auto const_5 = graphBuilder.AddNode("const_5", CONSTANT, 0, 1); | |||||
| auto less = graphBuilder.AddNode("less", LESS, 2, 1); | |||||
| auto loopcond = graphBuilder.AddNode("loopcond", LOOPCOND, 1, 1, FORMAT_NCHW, DT_BOOL); | |||||
| auto switch_i = graphBuilder.AddNode("switch_i", SWITCH, 2, 2); | |||||
| auto exit_i = graphBuilder.AddNode("switch_i", EXIT, 1, 1); | |||||
| auto const_1 = graphBuilder.AddNode("const_1", CONSTANT, 0, 1); | |||||
| auto add = graphBuilder.AddNode("add", ADD, 2, 1); | |||||
| auto next_iteration_i = graphBuilder.AddNode("next_iteration_i", NEXTITERATION, 1, 1); | |||||
| auto data_a = graphBuilder.AddNode("data_a", DATA, 1, 1); | |||||
| auto enter_a = graphBuilder.AddNode("enter_a", ENTER, 1, 1); | |||||
| ge::AttrUtils::SetStr(enter_a->GetOpDesc(), ENTER_ATTR_FRAME_NAME, "1"); | |||||
| auto merge_a = graphBuilder.AddNode("merge_a", MERGE, 2, 1); | |||||
| auto switch_a = graphBuilder.AddNode("switch_a", SWITCH, 2, 2); | |||||
| auto exit_a = graphBuilder.AddNode("exit_a", EXIT, 1, 1); | |||||
| auto mul = graphBuilder.AddNode("mul", MUL, 2, 1); | |||||
| auto const_2 = graphBuilder.AddNode("const_2", CONSTANT, 0, 1); | |||||
| auto next_iteration_a = graphBuilder.AddNode("next_iteration_a", NEXTITERATION, 1, 1); | |||||
| auto netoutput = graphBuilder.AddNode("netoutput", NETOUTPUT, 2, 2); | |||||
| // i = i+1 | |||||
| graphBuilder.AddDataEdge(data_i, 0, enter_i, 0); | |||||
| graphBuilder.AddDataEdge(enter_i, 0, merge_i, 0); | |||||
| graphBuilder.AddDataEdge(next_iteration_i, 0, merge_i, 1); | |||||
| graphBuilder.AddDataEdge(merge_i, 0, less, 0); | |||||
| graphBuilder.AddDataEdge(const_5, 0, less, 1); | |||||
| graphBuilder.AddDataEdge(less, 0, loopcond, 0); | |||||
| graphBuilder.AddDataEdge(loopcond, 0, switch_i, 1); | |||||
| graphBuilder.AddDataEdge(merge_i, 0, switch_i, 0); | |||||
| graphBuilder.AddDataEdge(switch_i, 0, exit_i, 0); | |||||
| graphBuilder.AddDataEdge(switch_i, 1, add, 0); | |||||
| graphBuilder.AddDataEdge(const_1, 0, add, 1); | |||||
| graphBuilder.AddDataEdge(add, 0, next_iteration_i, 0); | |||||
| graphBuilder.AddDataEdge(exit_i, 0, netoutput, 1); | |||||
| // a=a*2 | |||||
| graphBuilder.AddDataEdge(data_a, 0, enter_a, 0); | |||||
| graphBuilder.AddDataEdge(enter_a, 0, merge_a, 0); | |||||
| graphBuilder.AddDataEdge(next_iteration_a, 0, merge_a, 1); | |||||
| graphBuilder.AddDataEdge(loopcond, 0, switch_a, 1); | |||||
| graphBuilder.AddDataEdge(merge_a, 0, switch_a, 0); | |||||
| graphBuilder.AddDataEdge(switch_a, 0, exit_a, 0); | |||||
| graphBuilder.AddDataEdge(switch_a, 1, mul, 0); | |||||
| graphBuilder.AddDataEdge(const_2, 0, mul, 1); | |||||
| graphBuilder.AddDataEdge(mul, 0, next_iteration_a, 0); | |||||
| graphBuilder.AddDataEdge(exit_a, 0, netoutput, 0); | |||||
| // set const weight | |||||
| int64_t dims_size = 1; | int64_t dims_size = 1; | ||||
| vector<int64_t> data_vec = {5}; | vector<int64_t> data_vec = {5}; | ||||
| for_each(data_vec.begin(), data_vec.end(), [&](int64_t &data) { dims_size *= data; }); | for_each(data_vec.begin(), data_vec.end(), [&](int64_t &data) { dims_size *= data; }); | ||||
| vector<int32_t> data_value_vec(dims_size, 1); | vector<int32_t> data_value_vec(dims_size, 1); | ||||
| GeTensorDesc data_tensor_desc(GeShape(data_vec), FORMAT_NCHW, DT_INT32); | GeTensorDesc data_tensor_desc(GeShape(data_vec), FORMAT_NCHW, DT_INT32); | ||||
| GeTensorPtr data_tensor = | |||||
| make_shared<GeTensor>(data_tensor_desc, (uint8_t *)data_value_vec.data(), data_value_vec.size() * sizeof(int32_t)); | |||||
| OpDescUtils::SetWeights(const_5->GetOpDesc(), data_tensor); | |||||
| OpDescUtils::SetWeights(const_2->GetOpDesc(), data_tensor); | |||||
| OpDescUtils::SetWeights(const_1->GetOpDesc(), data_tensor); | |||||
| GeTensorPtr data_tensor = make_shared<GeTensor>(data_tensor_desc, (uint8_t *)data_value_vec.data(), | |||||
| data_value_vec.size() * sizeof(int32_t)); | |||||
| return graphBuilder.GetGraph(); | |||||
| auto enter = OP_CFG(ENTER).Attr(ENTER_ATTR_FRAME_NAME, "1"); | |||||
| auto const_op = OP_CFG(CONSTANT).Weight(data_tensor); | |||||
| DEF_GRAPH(g1) { | |||||
| CHAIN(NODE("data_i", DATA) | |||||
| ->NODE("enter_i", enter) | |||||
| ->EDGE(0, 0) | |||||
| ->NODE("merge_i", MERGE) | |||||
| ->NODE("less", LESS) | |||||
| ->NODE("loopcond", LOOPCOND)); | |||||
| CHAIN(NODE("const_1", const_op) | |||||
| ->EDGE(0, 1) | |||||
| ->NODE("add", ADD) | |||||
| ->NODE("iteration_i", NEXTITERATION) | |||||
| ->EDGE(0, 1) | |||||
| ->NODE("merge_i")); | |||||
| CHAIN(NODE("const_5", const_op)->EDGE(0, 1)->NODE("less")); | |||||
| CHAIN(NODE("loopcond") | |||||
| ->EDGE(0, 1) | |||||
| ->NODE("switch_i", SWITCH) | |||||
| ->EDGE(0, 0) | |||||
| ->NODE("exit_i", EXIT) | |||||
| ->EDGE(0, 1) | |||||
| ->NODE("netoutput", NETOUTPUT)); | |||||
| CHAIN(NODE("merge_i")->EDGE(0, 0)->NODE("switch_i")->EDGE(1, 0)->NODE("add")); | |||||
| CHAIN(NODE("data_a", DATA) | |||||
| ->NODE("enter_a", enter) | |||||
| ->NODE("merge_a", MERGE) | |||||
| ->NODE("switch_a", SWITCH) | |||||
| ->NODE("exit_a", EXIT) | |||||
| ->EDGE(0, 0) | |||||
| ->NODE("netoutput")); | |||||
| CHAIN(NODE("iteration_a", NEXTITERATION)->EDGE(0, 1)->NODE("merge_a")); | |||||
| CHAIN(NODE("loopcond")->EDGE(0, 1)->NODE("switch_a")->EDGE(1, 0)->NODE("mul", MUL)); | |||||
| CHAIN(NODE("const_2", const_op)->EDGE(0, 1)->NODE("mul")->EDGE(0, 0)->NODE("iteration_a")); | |||||
| }; | |||||
| return ToGeGraph(g1); | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| class FrameworkTest : public testing::Test { | class FrameworkTest : public testing::Test { | ||||
| protected: | protected: | ||||
| GeRunningEnvFaker ge_env; | |||||
| void SetUp() { ge_env.InstallDefault(); } | void SetUp() { ge_env.InstallDefault(); } | ||||
| void TearDown() {} | void TearDown() {} | ||||
| GeRunningEnvFaker ge_env; | |||||
| }; | }; | ||||
| /// data data | /// data data | ||||
| @@ -136,19 +107,19 @@ TEST_F(FrameworkTest, test_framework_add) { | |||||
| DEF_GRAPH(g1) { | DEF_GRAPH(g1) { | ||||
| CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | ||||
| CHAIN(NODE("data2", DATA)->NODE("add")); | CHAIN(NODE("data2", DATA)->NODE("add")); | ||||
| }); | |||||
| }; | |||||
| auto graph = ToGeGraph(g1); | |||||
| // new session & add graph | |||||
| map<AscendString, AscendString> options; | map<AscendString, AscendString> options; | ||||
| Session session(options); | Session session(options); | ||||
| auto ret = session.AddGraph(1, graph, options); | |||||
| EXPECT_EQ(ret, SUCCESS); | |||||
| // build input tensor | |||||
| session.AddGraph(1, ToGeGraph(g1), options); | |||||
| std::vector<InputTensorInfo> inputs; | std::vector<InputTensorInfo> inputs; | ||||
| // build_graph through session | |||||
| ret = session.BuildGraph(1, inputs); | |||||
| auto ret = session.BuildGraph(1, inputs); | |||||
| EXPECT_EQ(ret, SUCCESS); | EXPECT_EQ(ret, SUCCESS); | ||||
| CHECK_GRAPH(PreRunAfterBuild) { | |||||
| ASSERT_EQ(graph->GetName(), "g1_1"); | |||||
| ASSERT_EQ(graph->GetAllNodesSize(), 4); | |||||
| }; | |||||
| } | } | ||||
| /** data a = 2; | /** data a = 2; | ||||
| @@ -15,24 +15,12 @@ | |||||
| */ | */ | ||||
| #include <gtest/gtest.h> | #include <gtest/gtest.h> | ||||
| #include "easy_graph/graph/box.h" | |||||
| #include "easy_graph/graph/node.h" | |||||
| #include "external/ge/ge_api.h" | |||||
| #include "easy_graph/builder/graph_dsl.h" | #include "easy_graph/builder/graph_dsl.h" | ||||
| #include "easy_graph/builder/box_builder.h" | |||||
| #include "easy_graph/layout/graph_layout.h" | |||||
| #include "easy_graph/layout/engines/graph_easy/graph_easy_option.h" | |||||
| #include "easy_graph/layout/engines/graph_easy/graph_easy_executor.h" | |||||
| #include "graph/graph.h" | |||||
| #include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
| #include "framework/common/types.h" | #include "framework/common/types.h" | ||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "graph/ge_local_context.h" | |||||
| #include "ge_graph_dsl/graph_dsl.h" | #include "ge_graph_dsl/graph_dsl.h" | ||||
| #include "ge_graph_dsl/op_desc/op_desc_cfg_box.h" | |||||
| #define protected public | |||||
| #define private public | |||||
| #include "ge_opt_info/ge_opt_info.h" | |||||
| #undef private | |||||
| #undef protected | |||||
| namespace ge { | namespace ge { | ||||
| class STEST_opt_info : public testing::Test { | class STEST_opt_info : public testing::Test { | ||||
| @@ -52,7 +40,7 @@ TEST_F(STEST_opt_info, get_opt_info_all) { | |||||
| DEF_GRAPH(g1) { | DEF_GRAPH(g1) { | ||||
| CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | ||||
| CHAIN(NODE("data2", DATA)->NODE("add")); | CHAIN(NODE("data2", DATA)->NODE("add")); | ||||
| }); | |||||
| }; | |||||
| auto graph = ToGeGraph(g1); | auto graph = ToGeGraph(g1); | ||||
| @@ -95,7 +83,7 @@ TEST_F(STEST_opt_info, get_opt_info_success) { | |||||
| DEF_GRAPH(g1) { | DEF_GRAPH(g1) { | ||||
| CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | ||||
| CHAIN(NODE("data2", DATA)->NODE("add")); | CHAIN(NODE("data2", DATA)->NODE("add")); | ||||
| }); | |||||
| }; | |||||
| auto graph = ToGeGraph(g1); | auto graph = ToGeGraph(g1); | ||||
| @@ -15,9 +15,8 @@ | |||||
| */ | */ | ||||
| #include <gtest/gtest.h> | #include <gtest/gtest.h> | ||||
| #include "common/debug/log.h" | |||||
| #include "external/ge/ge_api.h" | #include "external/ge/ge_api.h" | ||||
| #include "ge_graph_dsl/assert/check_utils.h" | |||||
| #include "ge_running_env/include/ge_running_env/ge_running_env_faker.h" | #include "ge_running_env/include/ge_running_env/ge_running_env_faker.h" | ||||
| using namespace std; | using namespace std; | ||||
| @@ -31,6 +30,7 @@ int main(int argc, char **argv) { | |||||
| std::cout << "ge init failed , ret code:" << init_status << endl; | std::cout << "ge init failed , ret code:" << init_status << endl; | ||||
| } | } | ||||
| GeRunningEnvFaker::BackupEnv(); | GeRunningEnvFaker::BackupEnv(); | ||||
| CheckUtils::init(); | |||||
| testing::InitGoogleTest(&argc, argv); | testing::InitGoogleTest(&argc, argv); | ||||
| int ret = RUN_ALL_TESTS(); | int ret = RUN_ALL_TESTS(); | ||||
| return ret; | return ret; | ||||
| @@ -90,6 +90,7 @@ set(SRC_FILES | |||||
| "${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" | "${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/utils/dumper/ge_graph_dumper.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc" | ||||
| @@ -102,6 +102,7 @@ set(GRAPH_SRC_FILES | |||||
| "${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" | "${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/utils/dumper/ge_graph_dumper.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/utils/ge_ir_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/ge_ir_utils.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" | ||||
| @@ -116,7 +116,9 @@ TEST_F(UtestTaskGeneratorTest, FindLastBpFromBpNode) { | |||||
| TaskGenerator task_generator(nullptr, 0); | TaskGenerator task_generator(nullptr, 0); | ||||
| auto net_output = graph->FindNode("Node_Output"); | auto net_output = graph->FindNode("Node_Output"); | ||||
| // netoutput has no data input, return default value 0 | // netoutput has no data input, return default value 0 | ||||
| EXPECT_EQ(task_generator.FindLastBpFromBpNode(graph, net_output), 0); | |||||
| uint32_t bp_index = 0; | |||||
| EXPECT_EQ(task_generator.FindLastBpFromBpNode(graph, net_output, bp_index), 0); | |||||
| EXPECT_EQ(bp_index, 2); | |||||
| } | } | ||||
| TEST_F(UtestTaskGeneratorTest, UpdateOpIsVarAttr) { | TEST_F(UtestTaskGeneratorTest, UpdateOpIsVarAttr) { | ||||
| @@ -72,7 +72,7 @@ TEST(UtestGraphPassesAddnPass, null_pass) { | |||||
| AddNPass *addn_pass = nullptr; | AddNPass *addn_pass = nullptr; | ||||
| NamesToPass names_to_pass; | NamesToPass names_to_pass; | ||||
| names_to_pass.emplace_back("Test", addn_pass); | names_to_pass.emplace_back("Test", addn_pass); | ||||
| EXPECT_EQ(pass.Run(names_to_pass), SUCCESS); | |||||
| EXPECT_NE(pass.Run(names_to_pass), SUCCESS); | |||||
| } | } | ||||
| TEST(UtestGraphPassesAddnPass, null_graph) { | TEST(UtestGraphPassesAddnPass, null_graph) { | ||||
| @@ -67,6 +67,22 @@ class UtestTestPass : public BaseNodePass { | |||||
| names_to_add_repass_.erase(iter); | names_to_add_repass_.erase(iter); | ||||
| } | } | ||||
| } | } | ||||
| iter = names_to_add_repass_immediate_.find(node->GetName()); | |||||
| if (iter != names_to_add_repass_immediate_.end()) { | |||||
| auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes(); | |||||
| for (const auto &node_name : iter->second) { | |||||
| for (auto &node_re_pass : all_nodes) { | |||||
| if (node_re_pass->GetName() == node_name) { | |||||
| AddImmediateRePassNode(node_re_pass); | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| if (!dead_loop_) { | |||||
| names_to_add_repass_.erase(iter); | |||||
| } | |||||
| } | |||||
| // simulate infershape pass | // simulate infershape pass | ||||
| if(node->GetType() == WHILE){ | if(node->GetType() == WHILE){ | ||||
| bool need_repass = false; | bool need_repass = false; | ||||
| @@ -94,12 +110,17 @@ class UtestTestPass : public BaseNodePass { | |||||
| void AddDelNodeName(const std::string &iter_node, const std::string &del_node) { | void AddDelNodeName(const std::string &iter_node, const std::string &del_node) { | ||||
| names_to_add_del_[iter_node].insert(del_node); | names_to_add_del_[iter_node].insert(del_node); | ||||
| } | } | ||||
| void AddRePassImmediateNodeName(const std::string &iter_node, const std::string &re_pass_node) { | |||||
| names_to_add_repass_immediate_[iter_node].insert(re_pass_node); | |||||
| } | |||||
| unsigned int GetRunTimes() { return run_times_; } | unsigned int GetRunTimes() { return run_times_; } | ||||
| private: | private: | ||||
| std::vector<NodePtr> iter_nodes_; | std::vector<NodePtr> iter_nodes_; | ||||
| std::map<std::string, std::unordered_set<std::string>> names_to_add_del_; | std::map<std::string, std::unordered_set<std::string>> names_to_add_del_; | ||||
| std::map<std::string, std::unordered_set<std::string>> names_to_add_repass_; | std::map<std::string, std::unordered_set<std::string>> names_to_add_repass_; | ||||
| std::map<std::string, std::unordered_set<std::string>> names_to_add_repass_immediate_; | |||||
| bool dead_loop_; | bool dead_loop_; | ||||
| unsigned int run_times_; | unsigned int run_times_; | ||||
| }; | }; | ||||
| @@ -520,4 +541,98 @@ EXPECT_EQ(while_node->GetOpDesc()->GetSubgraphInstanceNames().size(),1); | |||||
| EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); | EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); | ||||
| } | } | ||||
| TEST_F(UTESTGraphPassesBasePass, re_pass_pre_node_immediately) { | |||||
| NamesToPass names_to_pass; | |||||
| auto test_pass = UtestTestPass(); | |||||
| names_to_pass.push_back(std::make_pair("test", &test_pass)); | |||||
| // repass pre_node immediately | |||||
| test_pass.AddRePassImmediateNodeName("reshape1", "add1"); | |||||
| auto graph = BuildGraph2(); | |||||
| auto ge_pass = GEPass(graph); | |||||
| EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); | |||||
| EXPECT_EQ(test_pass.GetIterNodes().size(), 9);// todo | |||||
| std::vector<std::unordered_set<std::string>> layers; | |||||
| layers.push_back({"data1", "const1", "const2"}); | |||||
| layers.push_back({"shape1"}); | |||||
| layers.push_back({"add1", "addn1"}); | |||||
| layers.push_back({"reshape1", "add1", "sum1"}); | |||||
| CheckIterOrder(&test_pass, layers); | |||||
| } | |||||
| /// sum1 | |||||
| /// / \. | |||||
| /// / \. | |||||
| /// / \. | |||||
| /// reshape1 addn1 | |||||
| /// | c | | |||||
| /// add1 <--- shape1 | |||||
| /// / \ | | |||||
| /// | | | | |||||
| /// data1 const1 const2 | |||||
| TEST_F(UTESTGraphPassesBasePass, re_pass_cur_node_immediately) { | |||||
| NamesToPass names_to_pass; | |||||
| auto test_pass = UtestTestPass(); | |||||
| names_to_pass.push_back(std::make_pair("test", &test_pass)); | |||||
| // repass cur_node immediately | |||||
| test_pass.AddRePassImmediateNodeName("reshape1", "reshape1"); | |||||
| auto graph = BuildGraph2(); | |||||
| auto ge_pass = GEPass(graph); | |||||
| EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); | |||||
| EXPECT_EQ(test_pass.GetIterNodes().size(), 9);// todo | |||||
| std::vector<std::unordered_set<std::string>> layers; | |||||
| layers.push_back({"data1", "const1", "const2"}); | |||||
| layers.push_back({"shape1"}); | |||||
| layers.push_back({"add1", "addn1"}); | |||||
| layers.push_back({"reshape1"}); | |||||
| layers.push_back({"reshape1", "sum1"}); | |||||
| CheckIterOrder(&test_pass, layers); | |||||
| } | |||||
| TEST_F(UTESTGraphPassesBasePass, re_pass_next_node_immediately) { | |||||
| NamesToPass names_to_pass; | |||||
| auto test_pass = UtestTestPass(); | |||||
| names_to_pass.push_back(std::make_pair("test", &test_pass)); | |||||
| // repass next_node immediately | |||||
| test_pass.AddRePassImmediateNodeName("reshape1", "sum1"); | |||||
| // repass node after next_node immediately | |||||
| test_pass.AddRePassImmediateNodeName("add1", "sum1"); | |||||
| auto graph = BuildGraph2(); | |||||
| auto ge_pass = GEPass(graph); | |||||
| EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); | |||||
| EXPECT_EQ(test_pass.GetIterNodes().size(), 8);// todo | |||||
| std::vector<std::unordered_set<std::string>> layers; | |||||
| layers.push_back({"data1", "const1", "const2"}); | |||||
| layers.push_back({"shape1"}); | |||||
| layers.push_back({"add1", "addn1"}); | |||||
| layers.push_back({"reshape1", "sum1"}); | |||||
| CheckIterOrder(&test_pass, layers); | |||||
| } | |||||
| /* | |||||
| TEST_F(UTESTGraphPassesBasePass, suspend_pre_node) { | |||||
| NamesToPass names_to_pass; | |||||
| auto test_pass = UtestTestPass(); | |||||
| names_to_pass.push_back(std::make_pair("test", &test_pass)); | |||||
| // repass next_node immediately | |||||
| test_pass.AddRePassNodeName("reshape1", "sum1"); | |||||
| // repass node after next_node immediately | |||||
| test_pass.AddRePassNodeName("add1", "sum1"); | |||||
| auto graph = BuildGraph2(); | |||||
| auto ge_pass = GEPass(graph); | |||||
| EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); | |||||
| EXPECT_EQ(test_pass.GetIterNodes().size(), 8);// todo | |||||
| std::vector<std::unordered_set<std::string>> layers; | |||||
| layers.push_back({"data1", "const1", "const2"}); | |||||
| layers.push_back({"shape1"}); | |||||
| layers.push_back({"add1", "addn1"}); | |||||
| layers.push_back({"reshape1", "sum1"}); | |||||
| CheckIterOrder(&test_pass, layers); | |||||
| }*/ | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -293,6 +293,9 @@ class AddKernel : public Kernel { | |||||
| } else if (input[0]->GetTensorDesc().GetDataType() == DT_INT32 || input[0]->GetTensorDesc().GetDataType() == DT_UINT32) { | } else if (input[0]->GetTensorDesc().GetDataType() == DT_INT32 || input[0]->GetTensorDesc().GetDataType() == DT_UINT32) { | ||||
| vector<int32_t> data_vec; | vector<int32_t> data_vec; | ||||
| auto data_num = input[0]->GetTensorDesc().GetShape().GetShapeSize(); | auto data_num = input[0]->GetTensorDesc().GetShape().GetShapeSize(); | ||||
| if (input[0]->GetTensorDesc().GetShape().IsScalar()) { | |||||
| data_num = 1; | |||||
| } | |||||
| auto x1_data = reinterpret_cast<const int32_t *>(input[0]->GetData().data()); | auto x1_data = reinterpret_cast<const int32_t *>(input[0]->GetData().data()); | ||||
| auto x2_data = reinterpret_cast<const int32_t *>(input[1]->GetData().data()); | auto x2_data = reinterpret_cast<const int32_t *>(input[1]->GetData().data()); | ||||
| for (size_t i = 0; i < data_num; i++) { | for (size_t i = 0; i < data_num; i++) { | ||||
| @@ -410,6 +413,48 @@ TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsHave | |||||
| EXPECT_EQ(unknown_target_value_range, output_value_range); | EXPECT_EQ(unknown_target_value_range, output_value_range); | ||||
| } | } | ||||
| TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsAreKnownValueRange_ScalarOutput) { | |||||
| // shape --- add --- sqrt | |||||
| // constant / | |||||
| auto graph = std::make_shared<ComputeGraph>("test_graph"); | |||||
| vector<int32_t> data_vec = {2}; | |||||
| GeTensorDesc const_td(ge::GeShape(), ge::FORMAT_NCHW, ge::DT_INT32); | |||||
| GeTensorPtr const_tensor = std::make_shared<ge::GeTensor>(const_td, (uint8_t *)data_vec.data(), sizeof(int32_t)); | |||||
| auto const_op_desc = std::make_shared<OpDesc>("Constant", "Constant"); | |||||
| const_op_desc->AddOutputDesc(const_td); | |||||
| EXPECT_EQ(OpDescUtils::SetWeights(const_op_desc, const_tensor), GRAPH_SUCCESS); | |||||
| auto const_node = graph->AddNode(const_op_desc); | |||||
| GeTensorDesc shape_td(GeShape(), ge::FORMAT_NCHW, ge::DT_INT32); | |||||
| std::vector<std::pair<int64_t, int64_t>> known_value_range = {make_pair(1, 100)}; | |||||
| shape_td.SetValueRange(known_value_range); | |||||
| auto shape_op_desc = std::make_shared<OpDesc>("Shape", "Shape"); | |||||
| shape_op_desc->AddOutputDesc(shape_td); | |||||
| auto shape_node = graph->AddNode(shape_op_desc); | |||||
| GeTensorDesc add_td(GeShape(), ge::FORMAT_NCHW, ge::DT_INT32); | |||||
| auto add_op_desc = std::make_shared<OpDesc>("Add", "Add"); | |||||
| add_op_desc->AddInputDesc(shape_td); | |||||
| add_op_desc->AddInputDesc(const_td); | |||||
| add_op_desc->AddOutputDesc(add_td); | |||||
| auto add_node = graph->AddNode(add_op_desc); | |||||
| ge::GraphUtils::AddEdge(shape_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(0)); | |||||
| ge::GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(1)); | |||||
| InferValueRangePass infer_pass; | |||||
| EXPECT_EQ(infer_pass.Run(add_node), SUCCESS); | |||||
| auto output_0_desc = add_node->GetOpDesc()->GetOutputDesc(0); | |||||
| std::vector<std::pair<int64_t, int64_t>> out_value_range; | |||||
| output_0_desc.GetValueRange(out_value_range); | |||||
| EXPECT_EQ(out_value_range.size(), 1); | |||||
| std::vector<int64_t> target_value_range = {3, 102}; | |||||
| std::vector<int64_t> output_value_range = {out_value_range[0].first, out_value_range[0].second}; | |||||
| EXPECT_EQ(output_value_range, target_value_range); | |||||
| } | |||||
| TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsAreKnownValueRange_Int64) { | TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsAreKnownValueRange_Int64) { | ||||
| // shape --- add --- sqrt | // shape --- add --- sqrt | ||||
| // constant / | // constant / | ||||
| @@ -103,4 +103,32 @@ TEST_F(UtestHybridModelAsyncExecutor, Test_execute) { | |||||
| context.callback_manager->callback_queue_.Push(eof_entry); | context.callback_manager->callback_queue_.Push(eof_entry); | ||||
| ASSERT_EQ(executor.Execute(args), SUCCESS); | ASSERT_EQ(executor.Execute(args), SUCCESS); | ||||
| } | } | ||||
| TEST_F(UtestHybridModelAsyncExecutor, test_PrepareInputs) { | |||||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
| GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); | |||||
| ge_root_model->SetModelName("test_name"); | |||||
| GeModelPtr ge_sub_model = make_shared<GeModel>(); | |||||
| HybridModel hybrid_model(ge_root_model); | |||||
| HybridModelAsyncExecutor executor(&hybrid_model); | |||||
| GeTensorDescPtr tensor_desc = make_shared<GeTensorDesc>(GeShape({-1, 16, 16, 3})); | |||||
| tensor_desc->SetShapeRange({{1, 256}, {16, 16}, {16, 16}, {3, 3}}); | |||||
| executor.input_tensor_desc_.insert({0, tensor_desc}); | |||||
| executor.device_id_ = 0; | |||||
| executor.input_sizes_.insert({0, -1}); | |||||
| executor.is_input_dynamic_.push_back(true); | |||||
| unique_ptr<uint8_t[]> data_buf(new (std::nothrow)uint8_t[3072]); | |||||
| InputData input_data; | |||||
| input_data.blobs.push_back(DataBuffer(data_buf.get(), 3072, false)); | |||||
| input_data.shapes.push_back({1, 16, 16, 3}); | |||||
| HybridModelExecutor::ExecuteArgs args; | |||||
| auto ret = executor.PrepareInputs(input_data, args); | |||||
| ASSERT_EQ(ret, SUCCESS); | |||||
| ASSERT_EQ(args.input_desc[0]->GetShape().ToString(), GeShape({1, 16, 16, 3}).ToString()); | |||||
| int64_t tensor_size = 0; | |||||
| TensorUtils::GetSize(*(args.input_desc[0]), tensor_size); | |||||
| ASSERT_EQ(tensor_size, 3104); | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -27,6 +27,7 @@ | |||||
| #include "hybrid/executor/hybrid_model_executor.h" | #include "hybrid/executor/hybrid_model_executor.h" | ||||
| #include "hybrid/executor/worker/execution_engine.h" | #include "hybrid/executor/worker/execution_engine.h" | ||||
| #include "hybrid/executor/subgraph_executor.h" | #include "hybrid/executor/subgraph_executor.h" | ||||
| #include "hybrid/executor/worker/task_compile_engine.h" | |||||
| #undef private | #undef private | ||||
| #undef protected | #undef protected | ||||
| @@ -45,7 +46,14 @@ class UtestExecutionEngine : public testing::Test { | |||||
| }; | }; | ||||
| namespace { | namespace { | ||||
| const int kIntBase = 10; | const int kIntBase = 10; | ||||
| class CompileNodeExecutor : public NodeExecutor { | |||||
| public: | |||||
| Status CompileTask(const HybridModel &model, const NodePtr &node, std::shared_ptr<NodeTask> &task) const override { | |||||
| return SUCCESS; | |||||
| } | |||||
| }; | |||||
| } | } | ||||
| static ge::OpDescPtr CreateOpDesc(string name = "", string type = "") { | static ge::OpDescPtr CreateOpDesc(string name = "", string type = "") { | ||||
| auto op_desc = std::make_shared<ge::OpDesc>(name, type); | auto op_desc = std::make_shared<ge::OpDesc>(name, type); | ||||
| op_desc->SetStreamId(0); | op_desc->SetStreamId(0); | ||||
| @@ -128,4 +136,8 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) { | |||||
| executor.InitCallback(node_state.get(), callback); | executor.InitCallback(node_state.get(), callback); | ||||
| ExecutionEngine execution_engine; | ExecutionEngine execution_engine; | ||||
| EXPECT_EQ(execution_engine.ExecuteAsync(*node_state, node_state->GetTaskContext(), execution_context, callback), INTERNAL_ERROR); | EXPECT_EQ(execution_engine.ExecuteAsync(*node_state, node_state->GetTaskContext(), execution_context, callback), INTERNAL_ERROR); | ||||
| CompileNodeExecutor node_executor; | |||||
| node_item->node_executor = &node_executor; | |||||
| EXPECT_EQ(TaskCompileEngine::Compile(*node_state, &execution_context), SUCCESS); | |||||
| } | } | ||||
| @@ -153,6 +153,7 @@ TEST_F(UtestGeHybrid, task_update_tiling_info) { | |||||
| ge::AttrUtils::SetStr(op_desc, "compile_info_json", "json"); | ge::AttrUtils::SetStr(op_desc, "compile_info_json", "json"); | ||||
| ge::AttrUtils::SetBool(op_desc, "support_dynamicshape", true); | ge::AttrUtils::SetBool(op_desc, "support_dynamicshape", true); | ||||
| ge::AttrUtils::SetInt(op_desc, "op_para_size", 1); | ge::AttrUtils::SetInt(op_desc, "op_para_size", 1); | ||||
| ge::AttrUtils::SetStr(op_desc, TVM_ATTR_NAME_MAGIC, "RT_DEV_BINARY_MAGIC_ELF"); | |||||
| auto node = graph->AddNode(op_desc); | auto node = graph->AddNode(op_desc); | ||||
| std::unique_ptr<NodeItem> node_item; | std::unique_ptr<NodeItem> node_item; | ||||
| @@ -87,6 +87,7 @@ TEST_F(NodeExecutorTest, TestGetOrCreateExecutor) { | |||||
| TEST_F(NodeExecutorTest, TestInitAndFinalize) { | TEST_F(NodeExecutorTest, TestInitAndFinalize) { | ||||
| auto &manager = NodeExecutorManager::GetInstance(); | auto &manager = NodeExecutorManager::GetInstance(); | ||||
| manager.FinalizeExecutors(); | manager.FinalizeExecutors(); | ||||
| manager.FinalizeExecutors(); | |||||
| manager.EnsureInitialized(); | manager.EnsureInitialized(); | ||||
| manager.EnsureInitialized(); | manager.EnsureInitialized(); | ||||
| const NodeExecutor *executor = nullptr; | const NodeExecutor *executor = nullptr; | ||||
| @@ -311,7 +311,7 @@ TEST_F(UtestSingleOpModel, BuildTaskList) { | |||||
| ASSERT_EQ(mem_task.LaunchKernel(0), SUCCESS); | ASSERT_EQ(mem_task.LaunchKernel(0), SUCCESS); | ||||
| } | } | ||||
| TEST_F(UtestSingleOpModel, build_aicpu_task) { | |||||
| TEST_F(UtestSingleOpModel, build_dynamic_task) { | |||||
| ComputeGraphPtr graph = make_shared<ComputeGraph>("single_op"); | ComputeGraphPtr graph = make_shared<ComputeGraph>("single_op"); | ||||
| GeModelPtr ge_model = make_shared<GeModel>(); | GeModelPtr ge_model = make_shared<GeModel>(); | ||||
| ge_model->SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph)); | ge_model->SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph)); | ||||
| @@ -321,6 +321,15 @@ TEST_F(UtestSingleOpModel, build_aicpu_task) { | |||||
| domi::TaskDef *task_def = model_task_def->add_task(); | domi::TaskDef *task_def = model_task_def->add_task(); | ||||
| task_def->set_type(RT_MODEL_TASK_KERNEL_EX); | task_def->set_type(RT_MODEL_TASK_KERNEL_EX); | ||||
| domi::TaskDef *task_def2 = model_task_def->add_task(); | |||||
| task_def2->set_type(RT_MODEL_TASK_KERNEL); | |||||
| domi::KernelDef *kernel_def = task_def2->mutable_kernel(); | |||||
| domi::KernelContext *context = kernel_def->mutable_context(); | |||||
| context->set_kernel_type(6); // ccKernelType::AI_CPU | |||||
| domi::TaskDef *task_def3 = model_task_def->add_task(); | |||||
| task_def3->set_type(RT_MODEL_TASK_ALL_KERNEL); | |||||
| string model_data_str = "123456789"; | string model_data_str = "123456789"; | ||||
| SingleOpModel model("model", model_data_str.c_str(), model_data_str.size()); | SingleOpModel model("model", model_data_str.c_str(), model_data_str.size()); | ||||
| std::mutex stream_mu; | std::mutex stream_mu; | ||||
| @@ -329,8 +338,18 @@ TEST_F(UtestSingleOpModel, build_aicpu_task) { | |||||
| DynamicSingleOp single_op(0, &stream_mu, stream); | DynamicSingleOp single_op(0, &stream_mu, stream); | ||||
| model.model_helper_.model_ = ge_model; | model.model_helper_.model_ = ge_model; | ||||
| auto op_desc = std::make_shared<ge::OpDesc>("add", "Add"); | auto op_desc = std::make_shared<ge::OpDesc>("add", "Add"); | ||||
| AttrUtils::SetStr(op_desc, TVM_ATTR_NAME_MAGIC, "RT_DEV_BINARY_MAGIC_ELF"); | |||||
| std::vector<char> kernelBin; | |||||
| TBEKernelPtr tbe_kernel = std::make_shared<ge::OpKernelBin>("name/Add", std::move(kernelBin)); | |||||
| op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel); | |||||
| NodePtr node = graph->AddNode(op_desc); | NodePtr node = graph->AddNode(op_desc); | ||||
| model.op_list_[0] = node; | model.op_list_[0] = node; | ||||
| StreamResource *res = new (std::nothrow) StreamResource(1); | StreamResource *res = new (std::nothrow) StreamResource(1); | ||||
| ASSERT_EQ(model.ParseTasks(), SUCCESS); | |||||
| ASSERT_EQ(model.BuildTaskListForDynamicOp(res, single_op), SUCCESS); | |||||
| model.tbe_tasks_.clear(); | |||||
| ASSERT_EQ(model.BuildTaskListForDynamicOp(res, single_op), SUCCESS); | ASSERT_EQ(model.BuildTaskListForDynamicOp(res, single_op), SUCCESS); | ||||
| model.aicpu_tasks_[0] = *task_def2; | |||||
| model.BuildTaskListForDynamicOp(res, single_op); | |||||
| } | } | ||||
| @@ -54,6 +54,7 @@ TEST_F(UtestSingleOpTask, test_build_kernel_task) { | |||||
| auto graph = make_shared<ComputeGraph>("graph"); | auto graph = make_shared<ComputeGraph>("graph"); | ||||
| auto op_desc = make_shared<OpDesc>("Add", "Add"); | auto op_desc = make_shared<OpDesc>("Add", "Add"); | ||||
| AttrUtils::SetStr(op_desc, TVM_ATTR_NAME_MAGIC, "RT_DEV_BINARY_MAGIC_ELF"); | |||||
| std::vector<char> kernelBin; | std::vector<char> kernelBin; | ||||
| TBEKernelPtr tbe_kernel = std::make_shared<ge::OpKernelBin>("name/Add", std::move(kernelBin)); | TBEKernelPtr tbe_kernel = std::make_shared<ge::OpKernelBin>("name/Add", std::move(kernelBin)); | ||||
| op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel); | op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel); | ||||
| @@ -38,6 +38,7 @@ static const int32_t ACL_ERROR_RT_STREAM_NO_CB_REG = 107015; // callba | |||||
| static const int32_t ACL_ERROR_RT_INVALID_MEMORY_TYPE = 107016; // invalid memory type | static const int32_t ACL_ERROR_RT_INVALID_MEMORY_TYPE = 107016; // invalid memory type | ||||
| static const int32_t ACL_ERROR_RT_INVALID_HANDLE = 107017; // invalid handle | static const int32_t ACL_ERROR_RT_INVALID_HANDLE = 107017; // invalid handle | ||||
| static const int32_t ACL_ERROR_RT_INVALID_MALLOC_TYPE = 107018; // invalid malloc type | static const int32_t ACL_ERROR_RT_INVALID_MALLOC_TYPE = 107018; // invalid malloc type | ||||
| static const int32_t ACL_ERROR_RT_WAIT_TIMEOUT = 107019; // wait timeout | |||||
| static const int32_t ACL_ERROR_RT_FEATURE_NOT_SUPPORT = 207000; // feature not support | static const int32_t ACL_ERROR_RT_FEATURE_NOT_SUPPORT = 207000; // feature not support | ||||
| static const int32_t ACL_ERROR_RT_MEMORY_ALLOCATION = 207001; // memory allocation error | static const int32_t ACL_ERROR_RT_MEMORY_ALLOCATION = 207001; // memory allocation error | ||||
| @@ -50,6 +51,7 @@ static const int32_t ACL_ERROR_RT_NO_EVENT_RESOURCE = 207007; // no eve | |||||
| static const int32_t ACL_ERROR_RT_NO_STREAM_RESOURCE = 207008; // no stream resource | static const int32_t ACL_ERROR_RT_NO_STREAM_RESOURCE = 207008; // no stream resource | ||||
| static const int32_t ACL_ERROR_RT_NO_NOTIFY_RESOURCE = 207009; // no notify resource | static const int32_t ACL_ERROR_RT_NO_NOTIFY_RESOURCE = 207009; // no notify resource | ||||
| static const int32_t ACL_ERROR_RT_NO_MODEL_RESOURCE = 207010; // no model resource | static const int32_t ACL_ERROR_RT_NO_MODEL_RESOURCE = 207010; // no model resource | ||||
| static const int32_t ACL_ERROR_RT_NO_CDQ_RESOURCE = 207011; // no cdq resource | |||||
| static const int32_t ACL_ERROR_RT_INTERNAL_ERROR = 507000; // runtime internal error | static const int32_t ACL_ERROR_RT_INTERNAL_ERROR = 507000; // runtime internal error | ||||
| static const int32_t ACL_ERROR_RT_TS_ERROR = 507001; // ts internel error | static const int32_t ACL_ERROR_RT_TS_ERROR = 507001; // ts internel error | ||||
| @@ -85,9 +87,14 @@ static const int32_t ACL_ERROR_RT_DEBUG_UNREGISTER_FAIL = 507030; // debug | |||||
| static const int32_t ACL_ERROR_RT_LABEL_CONTEXT = 507031; // label not in current context | static const int32_t ACL_ERROR_RT_LABEL_CONTEXT = 507031; // label not in current context | ||||
| static const int32_t ACL_ERROR_RT_PROGRAM_USE_OUT = 507032; // program register num use out | static const int32_t ACL_ERROR_RT_PROGRAM_USE_OUT = 507032; // program register num use out | ||||
| static const int32_t ACL_ERROR_RT_DEV_SETUP_ERROR = 507033; // device setup error | static const int32_t ACL_ERROR_RT_DEV_SETUP_ERROR = 507033; // device setup error | ||||
| static const int32_t ACL_ERROR_RT_VECTOR_CORE_TIMEOUT = 507034; // vector core timeout | |||||
| static const int32_t ACL_ERROR_RT_VECTOR_CORE_EXCEPTION = 507035; // vector core exception | |||||
| static const int32_t ACL_ERROR_RT_VECTOR_CORE_TRAP_EXCEPTION = 507036; // vector core trap exception | |||||
| static const int32_t ACL_ERROR_RT_CDQ_BATCH_ABNORMAL = 507037; // cdq alloc batch abnormal | |||||
| static const int32_t ACL_ERROR_RT_DRV_INTERNAL_ERROR = 507899; // drv internal error | static const int32_t ACL_ERROR_RT_DRV_INTERNAL_ERROR = 507899; // drv internal error | ||||
| static const int32_t ACL_ERROR_RT_AICPU_INTERNAL_ERROR = 507900; // aicpu internal error | static const int32_t ACL_ERROR_RT_AICPU_INTERNAL_ERROR = 507900; // aicpu internal error | ||||
| static const int32_t ACL_ERROR_RT_SOCKET_CLOSE = 507901; // hdc disconnect | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| @@ -156,7 +156,7 @@ RTS_API rtError_t rtProfilerTrace(uint64_t id, bool notify, uint32_t flags, rtSt | |||||
| /** | /** | ||||
| * @ingroup profiling_base | * @ingroup profiling_base | ||||
| * @brief ts send keypoint for step info. | |||||
| * @brief ts send keypoint profiler log. | |||||
| */ | */ | ||||
| RTS_API rtError_t rtProfilerTraceEx(uint64_t id, uint64_t modelId, uint16_t tagId, rtStream_t stream); | RTS_API rtError_t rtProfilerTraceEx(uint64_t id, uint64_t modelId, uint16_t tagId, rtStream_t stream); | ||||
| @@ -206,7 +206,7 @@ RTS_API rtError_t rtRegDeviceStateCallback(const char *regName, rtDeviceStateCal | |||||
| /** | /** | ||||
| * @ingroup dvrt_base | * @ingroup dvrt_base | ||||
| * @brief register callback for fail task | |||||
| * @brief register callback for fail task | |||||
| * @param [in] uniName unique register name, can't be null | * @param [in] uniName unique register name, can't be null | ||||
| * @param [in] callback fail task callback function | * @param [in] callback fail task callback function | ||||
| * @param [out] NA | * @param [out] NA | ||||
| @@ -345,11 +345,11 @@ RTS_API rtError_t rtLabelCreateEx(rtLabel_t *label, rtStream_t stream); | |||||
| * @return RT_ERROR_NONE for ok | * @return RT_ERROR_NONE for ok | ||||
| * @return RT_ERROR_INVALID_VALUE for error input | * @return RT_ERROR_INVALID_VALUE for error input | ||||
| */ | */ | ||||
| rtError_t rtLabelCreateExV2(rtLabel_t *label, rtModel_t model, rtStream_t stream); | |||||
| RTS_API rtError_t rtLabelCreateExV2(rtLabel_t *label, rtModel_t model, rtStream_t stream); | |||||
| /** | /** | ||||
| * @ingroup dvrt_base | * @ingroup dvrt_base | ||||
| * @brief get current thread last stream id and task id | |||||
| * @brief get current thread last stream id and task id | |||||
| * @param [out] stream id and task id | * @param [out] stream id and task id | ||||
| * @param [in] null | * @param [in] null | ||||
| * @return RT_ERROR_NONE for ok | * @return RT_ERROR_NONE for ok | ||||
| @@ -46,6 +46,12 @@ typedef enum tagRtChipType { | |||||
| CHIP_END, | CHIP_END, | ||||
| } rtChipType_t; | } rtChipType_t; | ||||
| typedef enum tagRtAicpuScheType { | |||||
| SCHEDULE_SOFTWARE = 0, /* Software Schedule */ | |||||
| SCHEDULE_SOFTWARE_OPT, | |||||
| SCHEDULE_HARDWARE, /* HWTS Schedule */ | |||||
| } rtAicpuScheType; | |||||
| typedef enum tagRtVersion { | typedef enum tagRtVersion { | ||||
| VER_BEGIN = 0, | VER_BEGIN = 0, | ||||
| VER_NA = VER_BEGIN, | VER_NA = VER_BEGIN, | ||||
| @@ -65,6 +71,7 @@ typedef enum tagRtPlatformType { | |||||
| PLATFORM_LHISI_CS, | PLATFORM_LHISI_CS, | ||||
| PLATFORM_DC, | PLATFORM_DC, | ||||
| PLATFORM_CLOUD_V2, | PLATFORM_CLOUD_V2, | ||||
| PLATFORM_LHISI_SD3403, | |||||
| PLATFORM_END, | PLATFORM_END, | ||||
| } rtPlatformType_t; | } rtPlatformType_t; | ||||
| @@ -126,6 +133,11 @@ typedef struct tagRtPlatformConfig { | |||||
| uint32_t platformConfig; | uint32_t platformConfig; | ||||
| } rtPlatformConfig_t; | } rtPlatformConfig_t; | ||||
| typedef enum tagRTTaskTimeoutType { | |||||
| RT_TIMEOUT_TYPE_OP_WAIT = 0, | |||||
| RT_TIMEOUT_TYPE_OP_EXECUTE, | |||||
| } rtTaskTimeoutType_t; | |||||
| /** | /** | ||||
| * @ingroup | * @ingroup | ||||
| * @brief get AI core count | * @brief get AI core count | ||||
| @@ -184,6 +196,37 @@ RTS_API rtError_t rtMemGetL2Info(rtStream_t stream, void **ptr, uint32_t *size); | |||||
| */ | */ | ||||
| RTS_API rtError_t rtGetRuntimeVersion(uint32_t *runtimeVersion); | RTS_API rtError_t rtGetRuntimeVersion(uint32_t *runtimeVersion); | ||||
| /** | |||||
| * @ingroup | |||||
| * @brief get device feature ability by device id, such as task schedule ability. | |||||
| * @param [in] deviceId | |||||
| * @param [in] moduleType | |||||
| * @param [in] featureType | |||||
| * @param [out] value | |||||
| * @return RT_ERROR_NONE for ok | |||||
| * @return RT_ERROR_INVALID_VALUE for error input | |||||
| */ | |||||
| RTS_API rtError_t rtGetDeviceCapability(int32_t deviceId, int32_t moduleType, int32_t featureType, int32_t *value); | |||||
| /** | |||||
| * @ingroup | |||||
| * @brief set event wait task timeout time. | |||||
| * @param [in] timeout | |||||
| * @return RT_ERROR_NONE for ok | |||||
| * @return RT_ERROR_INVALID_VALUE for error input | |||||
| */ | |||||
| RTS_API rtError_t rtSetOpWaitTimeOut(uint32_t timeout); | |||||
| /** | |||||
| * @ingroup | |||||
| * @brief set op execute task timeout time. | |||||
| * @param [in] timeout | |||||
| * @return RT_ERROR_NONE for ok | |||||
| * @return RT_ERROR_INVALID_VALUE for error input | |||||
| */ | |||||
| RTS_API rtError_t rtSetOpExecuteTimeOut(uint32_t timeout); | |||||
| #if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) | #if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -63,6 +63,11 @@ typedef enum tagRtFeatureType { | |||||
| FEATURE_TYPE_RSV | FEATURE_TYPE_RSV | ||||
| } rtFeatureType_t; | } rtFeatureType_t; | ||||
| typedef enum tagRtDeviceFeatureType { | |||||
| FEATURE_TYPE_SCHE, | |||||
| FEATURE_TYPE_END, | |||||
| } rtDeviceFeatureType_t; | |||||
| typedef enum tagMemcpyInfo { | typedef enum tagMemcpyInfo { | ||||
| MEMCPY_INFO_SUPPORT_ZEROCOPY = 0, | MEMCPY_INFO_SUPPORT_ZEROCOPY = 0, | ||||
| MEMCPY_INFO_RSV | MEMCPY_INFO_RSV | ||||
| @@ -23,12 +23,23 @@ | |||||
| extern "C" { | extern "C" { | ||||
| #endif | #endif | ||||
| typedef enum rtEventWaitStatus { | |||||
| EVENT_STATUS_COMPLETE = 0, | |||||
| EVENT_STATUS_NOT_READY = 1, | |||||
| EVENT_STATUS_MAX = 2, | |||||
| } rtEventWaitStatus_t; | |||||
| /** | /** | ||||
| * @ingroup event_flags | * @ingroup event_flags | ||||
| * @brief event op bit flags | * @brief event op bit flags | ||||
| */ | */ | ||||
| #define RT_EVENT_DEFAULT (0x00) | |||||
| #define RT_EVENT_WITH_FLAG (0x01) | |||||
| #define RT_EVENT_DEFAULT (0x0E) | |||||
| #define RT_EVENT_WITH_FLAG (0x0B) | |||||
| #define RT_EVENT_DDSYNC_NS 0x01U | |||||
| #define RT_EVENT_STREAM_MARK 0x02U | |||||
| #define RT_EVENT_DDSYNC 0x04U | |||||
| #define RT_EVENT_TIME_LINE 0x08U | |||||
| /** | /** | ||||
| * @ingroup dvrt_event | * @ingroup dvrt_event | ||||
| @@ -104,6 +115,16 @@ RTS_API rtError_t rtEventSynchronize(rtEvent_t event); | |||||
| */ | */ | ||||
| RTS_API rtError_t rtEventQuery(rtEvent_t event); | RTS_API rtError_t rtEventQuery(rtEvent_t event); | ||||
| /** | |||||
| * @ingroup dvrt_event | |||||
| * @brief Queries an event's wait status | |||||
| * @param [in] event event to query | |||||
| * @param [in out] status event wait status (rtEventWaitStatus_t) | |||||
| * @return EVENT_STATUS_COMPLETE for complete | |||||
| * @return EVENT_STATUS_NOT_READY for not complete | |||||
| */ | |||||
| RTS_API rtError_t rtEventQueryWaitStatus(rtEvent_t event, rtEventWaitStatus_t *status); | |||||
| /** | /** | ||||
| * @ingroup dvrt_event | * @ingroup dvrt_event | ||||
| * @brief computes the elapsed time between events. | * @brief computes the elapsed time between events. | ||||
| @@ -176,6 +197,18 @@ RTS_API rtError_t rtNotifyRecord(rtNotify_t notify, rtStream_t stream); | |||||
| */ | */ | ||||
| RTS_API rtError_t rtNotifyWait(rtNotify_t notify, rtStream_t stream); | RTS_API rtError_t rtNotifyWait(rtNotify_t notify, rtStream_t stream); | ||||
| /** | |||||
| * @ingroup dvrt_event | |||||
| * @brief Wait for a notify with time out | |||||
| * @param [in] notify_ notify to be wait | |||||
| * @param [in] stream_ input stream | |||||
| * @param [in] timeOut input timeOut | |||||
| * @return RT_ERROR_NONE for ok | |||||
| * @return RT_ERROR_INVALID_VALUE for error input | |||||
| * @return RT_ERROR_STREAM_CONTEXT for stream is not in current ctx | |||||
| */ | |||||
| RTS_API rtError_t rtNotifyWaitWithTimeOut(rtNotify_t notify_, rtStream_t stream_, uint32_t timeOut); | |||||
| /** | /** | ||||
| * @ingroup dvrt_event | * @ingroup dvrt_event | ||||
| * @brief Name a notify | * @brief Name a notify | ||||
| @@ -111,6 +111,16 @@ typedef struct rtKernelInfo { | |||||
| uint32_t module_size; | uint32_t module_size; | ||||
| } *rtKernelInfo_t; | } *rtKernelInfo_t; | ||||
| /** | |||||
| * @ingroup rt_kernel | |||||
| * @brief op name | |||||
| */ | |||||
| typedef struct rtKernelLaunchNames { | |||||
| const char *soName; // defined for so name | |||||
| const char *kernelName; // defined for kernel type name | |||||
| const char *opName; // defined for operator name | |||||
| } rtKernelLaunchNames_t; | |||||
| /** | /** | ||||
| * @ingroup rt_KernelConfigDump | * @ingroup rt_KernelConfigDump | ||||
| * @brief device dump type | * @brief device dump type | ||||
| @@ -173,13 +183,7 @@ typedef void (*rtCallback_t)(void *fnData); | |||||
| * @ingroup rt_kernel | * @ingroup rt_kernel | ||||
| * @brief magic number of elf binary for aicube | * @brief magic number of elf binary for aicube | ||||
| */ | */ | ||||
| #define RT_DEV_BINARY_MAGIC_ELF_AICUBE 0x41415247 | |||||
| /** | |||||
| * @ingroup rt_kernel | |||||
| * @brief magic number of elf binary for aivector | |||||
| */ | |||||
| #define RT_DEV_BINARY_MAGIC_ELF_AIVECTOR 0x41415248 | |||||
| #define RT_DEV_BINARY_MAGIC_ELF_AICUBE 0x41494343 | |||||
| /** | /** | ||||
| * @ingroup rt_kernel_flags | * @ingroup rt_kernel_flags | ||||
| @@ -192,14 +196,14 @@ typedef void (*rtCallback_t)(void *fnData); | |||||
| #define RT_KERNEL_CUSTOM_AICPU (0x08) | #define RT_KERNEL_CUSTOM_AICPU (0x08) | ||||
| // STARS topic scheduler sqe : topic_type | // STARS topic scheduler sqe : topic_type | ||||
| #define RT_KERNEL_DEVICE_FIRST (0X10) | |||||
| #define RT_KERNEL_HOST_ONLY (0X20) | |||||
| #define RT_KERNEL_HOST_FIRST (0X30) | |||||
| #define RT_KERNEL_DEVICE_FIRST (0x10) | |||||
| #define RT_KERNEL_HOST_ONLY (0x20) | |||||
| #define RT_KERNEL_HOST_FIRST (0x40) | |||||
| /** | /** | ||||
| * @ingroup rt_kernel | * @ingroup rt_kernel | ||||
| * @brief kernel mode | * @brief kernel mode | ||||
| */ | |||||
| **/ | |||||
| #define RT_DEFAULT_KERNEL_MODE (0x00) | #define RT_DEFAULT_KERNEL_MODE (0x00) | ||||
| #define RT_NORMAL_KERNEL_MODE (0x01) | #define RT_NORMAL_KERNEL_MODE (0x01) | ||||
| #define RT_ALL_KERNEL_MODE (0x02) | #define RT_ALL_KERNEL_MODE (0x02) | ||||
| @@ -222,7 +226,7 @@ RTS_API rtError_t rtDevBinaryRegister(const rtDevBinary_t *bin, void **handle); | |||||
| /** | /** | ||||
| * @ingroup rt_kernel | * @ingroup rt_kernel | ||||
| * @brief register device binary | |||||
| * @brief register device binary with all kernel | |||||
| * @param [in] bin device binary description | * @param [in] bin device binary description | ||||
| * @param [out] handle device binary handle | * @param [out] handle device binary handle | ||||
| * @return RT_ERROR_NONE for ok | * @return RT_ERROR_NONE for ok | ||||
| @@ -341,7 +345,7 @@ RTS_API rtError_t rtKernelLaunch(const void *stubFunc, uint32_t blockDim, void * | |||||
| * @ingroup rt_kernel | * @ingroup rt_kernel | ||||
| * @brief launch kernel with handle to device | * @brief launch kernel with handle to device | ||||
| * @param [in] handle program | * @param [in] handle program | ||||
| * @param [in] devFunc device function description | |||||
| * @param [in] devFunc device function description. | |||||
| * @param [in] blockDim block dimensions | * @param [in] blockDim block dimensions | ||||
| * @param [in] args arguments address for kernel function | * @param [in] args arguments address for kernel function | ||||
| * @param [in] argsSize arguments size | * @param [in] argsSize arguments size | ||||
| @@ -352,7 +356,7 @@ RTS_API rtError_t rtKernelLaunch(const void *stubFunc, uint32_t blockDim, void * | |||||
| * @return RT_ERROR_INVALID_VALUE for error input | * @return RT_ERROR_INVALID_VALUE for error input | ||||
| */ | */ | ||||
| RTS_API rtError_t rtKernelLaunchWithHandle(void *handle, const void *devFunc, uint32_t blockDim, void *args, uint32_t argsSize, | RTS_API rtError_t rtKernelLaunchWithHandle(void *handle, const void *devFunc, uint32_t blockDim, void *args, uint32_t argsSize, | ||||
| rtSmDesc_t *smDesc, rtStream_t stream, const void *kernelInfo); | |||||
| rtSmDesc_t *smDesc, rtStream_t stream_, const void *kernelInfo); | |||||
| /** | /** | ||||
| * @ingroup rt_kernel | * @ingroup rt_kernel | ||||
| @@ -371,7 +375,7 @@ RTS_API rtError_t rtKernelLaunchWithFlag(const void *stubFunc, uint32_t blockDim | |||||
| rtSmDesc_t *smDesc, rtStream_t stream, uint32_t flags); | rtSmDesc_t *smDesc, rtStream_t stream, uint32_t flags); | ||||
| /** | /** | ||||
| * @ingroup rt_kernel | |||||
| * @ingroup rt_kernel(abandoned) | |||||
| * @brief launch kernel to device | * @brief launch kernel to device | ||||
| * @param [in] args arguments address for kernel function | * @param [in] args arguments address for kernel function | ||||
| * @param [in] argsSize arguments size | * @param [in] argsSize arguments size | ||||
| @@ -383,7 +387,21 @@ RTS_API rtError_t rtKernelLaunchWithFlag(const void *stubFunc, uint32_t blockDim | |||||
| RTS_API rtError_t rtKernelLaunchEx(void *args, uint32_t argsSize, uint32_t flags, rtStream_t stream); | RTS_API rtError_t rtKernelLaunchEx(void *args, uint32_t argsSize, uint32_t flags, rtStream_t stream); | ||||
| /** | /** | ||||
| * @ingroup rt_kernel | |||||
| * @ingroup rt_kernel(in use) | |||||
| * @brief launch kernel to device | |||||
| * @param [in] opName op kernel name | |||||
| * @param [in] args arguments address for kernel function | |||||
| * @param [in] argsSize arguments size | |||||
| * @param [in] flags launch flags | |||||
| * @param [in] stream associated stream | |||||
| * @return RT_ERROR_NONE for ok | |||||
| * @return RT_ERROR_INVALID_VALUE for error input | |||||
| */ | |||||
| RTS_API rtError_t rtKernelLaunchFwk(const char *opName, void *args, uint32_t argsSize, uint32_t flags, | |||||
| rtStream_t rtStream); | |||||
| /** | |||||
| * @ingroup rt_kernel(abandoned) | |||||
| * @brief launch cpu kernel to device | * @brief launch cpu kernel to device | ||||
| * @param [in] soName so name | * @param [in] soName so name | ||||
| * @param [in] kernelName kernel name | * @param [in] kernelName kernel name | ||||
| @@ -399,7 +417,22 @@ RTS_API rtError_t rtCpuKernelLaunch(const void *soName, const void *kernelName, | |||||
| uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream); | uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream); | ||||
| /** | /** | ||||
| * @ingroup rt_kernel | |||||
| * @ingroup rt_kernel(in use) | |||||
| * @brief launch cpu kernel to device | |||||
| * @param [in] launchNames names for kernel launch | |||||
| * @param [in] blockDim block dimensions | |||||
| * @param [in] args arguments address for kernel function | |||||
| * @param [in] argsSize arguments size | |||||
| * @param [in] smDesc shared memory description | |||||
| * @param [in] stream associated stream | |||||
| * @return RT_ERROR_NONE for ok | |||||
| * @return RT_ERROR_INVALID_VALUE for error input | |||||
| */ | |||||
| RTS_API rtError_t rtAicpuKernelLaunch(const rtKernelLaunchNames_t *launchNames, | |||||
| uint32_t blockDim, const void *args, uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream); | |||||
| /** | |||||
| * @ingroup rt_kernel(abandoned) | |||||
| * @brief launch cpu kernel to device with dump identifier | * @brief launch cpu kernel to device with dump identifier | ||||
| * @param [in] soName so name | * @param [in] soName so name | ||||
| * @param [in] kernelName kernel name | * @param [in] kernelName kernel name | ||||
| @@ -416,6 +449,22 @@ RTS_API rtError_t rtCpuKernelLaunchWithFlag(const void *soName, const void *kern | |||||
| const void *args, uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream, | const void *args, uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream, | ||||
| uint32_t flags); | uint32_t flags); | ||||
| /** | |||||
| * @ingroup rt_kernel(in use) | |||||
| * @brief launch cpu kernel to device with dump identifier | |||||
| * @param [in] launchNames names for kernel launch | |||||
| * @param [in] blockDim block dimensions | |||||
| * @param [in] args arguments address for kernel function | |||||
| * @param [in] argsSize arguments size | |||||
| * @param [in] smDesc shared memory description | |||||
| * @param [in] stream associated stream | |||||
| * @param [in] flags dump flag or other function flags | |||||
| * @return RT_ERROR_NONE for ok | |||||
| * @return RT_ERROR_INVALID_VALUE for error input | |||||
| */ | |||||
| RTS_API rtError_t rtAicpuKernelLaunchWithFlag(const rtKernelLaunchNames_t *launchNames, uint32_t blockDim, | |||||
| const void *args, uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream, uint32_t flags); | |||||
| /** | /** | ||||
| * @ingroup rt_kernel | * @ingroup rt_kernel | ||||
| * @brief L1 fusion dump addr transferred to device | * @brief L1 fusion dump addr transferred to device | ||||
| @@ -116,6 +116,9 @@ typedef enum tagRtMemInfoType { | |||||
| typedef enum tagRtRecudeKind { | typedef enum tagRtRecudeKind { | ||||
| RT_MEMCPY_SDMA_AUTOMATIC_ADD = 10, // D2D, SDMA inline reduce, include 1P, and P2P | RT_MEMCPY_SDMA_AUTOMATIC_ADD = 10, // D2D, SDMA inline reduce, include 1P, and P2P | ||||
| RT_MEMCPY_SDMA_AUTOMATIC_MAX = 11, | |||||
| RT_MEMCPY_SDMA_AUTOMATIC_MIN = 12, | |||||
| RT_MEMCPY_SDMA_AUTOMATIC_EQUAL = 13, | |||||
| RT_RECUDE_KIND_END | RT_RECUDE_KIND_END | ||||
| } rtRecudeKind_t; | } rtRecudeKind_t; | ||||
| @@ -123,6 +126,14 @@ typedef enum tagRtDataType { | |||||
| RT_DATA_TYPE_FP32 = 0, // fp32 | RT_DATA_TYPE_FP32 = 0, // fp32 | ||||
| RT_DATA_TYPE_FP16 = 1, // fp16 | RT_DATA_TYPE_FP16 = 1, // fp16 | ||||
| RT_DATA_TYPE_INT16 = 2, // int16 | RT_DATA_TYPE_INT16 = 2, // int16 | ||||
| RT_DATA_TYPE_INT4 = 3, // int4 | |||||
| RT_DATA_TYPE_INT8 = 4, // int8 | |||||
| RT_DATA_TYPE_INT32 = 5, // int32 | |||||
| RT_DATA_TYPE_BFP16 = 6, // bfp16 | |||||
| RT_DATA_TYPE_BFP32 = 7, // bfp32 | |||||
| RT_DATA_TYPE_UINT8 = 8, // uint8 | |||||
| RT_DATA_TYPE_UINT16= 9, // uint16 | |||||
| RT_DATA_TYPE_UINT32= 10,// uint32 | |||||
| RT_DATA_TYPE_END | RT_DATA_TYPE_END | ||||
| } rtDataType_t; | } rtDataType_t; | ||||
| @@ -135,12 +135,13 @@ typedef struct tagAllKernelTaskInfo { | |||||
| uint16_t argsCount; | uint16_t argsCount; | ||||
| uint16_t argsSize; | uint16_t argsSize; | ||||
| uint16_t reserved; | uint16_t reserved; | ||||
| const void *dev_func; | |||||
| void *devfunc; | |||||
| void *handle; | void *handle; | ||||
| uint8_t *smDesc; | uint8_t *smDesc; | ||||
| uint8_t *args; | uint8_t *args; | ||||
| uint16_t *argsOffset; | uint16_t *argsOffset; | ||||
| } rtAllKernelTaskInfo_t; | } rtAllKernelTaskInfo_t; | ||||
| typedef struct tagKernelTaskInfoEx { | typedef struct tagKernelTaskInfoEx { | ||||
| uint32_t flags; | uint32_t flags; | ||||
| uint32_t argsSize; | uint32_t argsSize; | ||||
| @@ -198,6 +199,13 @@ typedef struct tagProfilerTraceTaskInfo { | |||||
| uint32_t reserved[6]; | uint32_t reserved[6]; | ||||
| } rtProfilerTrace_t; | } rtProfilerTrace_t; | ||||
| typedef struct tagProfilerTraceExTaskInfo { | |||||
| uint64_t profilerTraceId; | |||||
| uint64_t modelId; | |||||
| uint16_t tagId; | |||||
| uint8_t reserved[22]; | |||||
| } rtProfilerTraceEx_t; | |||||
| typedef struct tagrtMemcpyAsyncTaskInfo { | typedef struct tagrtMemcpyAsyncTaskInfo { | ||||
| void *dst; | void *dst; | ||||
| uint64_t destMax; | uint64_t destMax; | ||||
| @@ -265,7 +273,7 @@ typedef struct tagTaskInfo { | |||||
| union { | union { | ||||
| rtKernelTaskInfoEx_t kernelTaskEx; | rtKernelTaskInfoEx_t kernelTaskEx; | ||||
| rtKernelTaskInfo_t kernelTask; | rtKernelTaskInfo_t kernelTask; | ||||
| rtAllKernelTaskInfo_t allkernelTask; | |||||
| rtAllKernelTaskInfo_t allKernelTask; | |||||
| rtEventTaskInfo_t eventTask; | rtEventTaskInfo_t eventTask; | ||||
| rtStreamSwitchTaskInfo_t streamSwitchTask; | rtStreamSwitchTaskInfo_t streamSwitchTask; | ||||
| rtStreamActiveTaskInfo_t streamActiveTask; | rtStreamActiveTaskInfo_t streamActiveTask; | ||||
| @@ -273,6 +281,7 @@ typedef struct tagTaskInfo { | |||||
| rtLabelSwitchTaskInfo_t labelSwitchTask; | rtLabelSwitchTaskInfo_t labelSwitchTask; | ||||
| rtLabelGotoTaskInfo_t labelGotoTask; | rtLabelGotoTaskInfo_t labelGotoTask; | ||||
| rtProfilerTrace_t profilertraceTask; | rtProfilerTrace_t profilertraceTask; | ||||
| rtProfilerTraceEx_t profilertraceExTask; | |||||
| rtMemcpyAsyncTaskInfo_t memcpyAsyncTask; | rtMemcpyAsyncTaskInfo_t memcpyAsyncTask; | ||||
| rtNotifyTaskInfo_t notifyTask; | rtNotifyTaskInfo_t notifyTask; | ||||
| rtReduceAsyncTaskInfo_t reduceAsyncTask; | rtReduceAsyncTaskInfo_t reduceAsyncTask; | ||||
| @@ -108,7 +108,19 @@ enum MsprofCtrlCallbackType { | |||||
| MSPROF_CTRL_INIT_ACL_ENV = 0, // start profiling with acl env | MSPROF_CTRL_INIT_ACL_ENV = 0, // start profiling with acl env | ||||
| MSPROF_CTRL_INIT_ACL_JSON, // start profiling with acl.json | MSPROF_CTRL_INIT_ACL_JSON, // start profiling with acl.json | ||||
| MSPROF_CTRL_INIT_GE_OPTIONS, // start profiling with ge env and options | MSPROF_CTRL_INIT_GE_OPTIONS, // start profiling with ge env and options | ||||
| MSPROF_CTRL_FINALIZE // stop profiling | |||||
| MSPROF_CTRL_FINALIZE, // stop profiling | |||||
| MSPROF_CTRL_REPORT_FUN_P, // for report callback | |||||
| MSPROF_CTRL_PROF_SWITCH_ON, // for prof switch on | |||||
| MSPROF_CTRL_PROF_SWITCH_OFF // for prof switch off | |||||
| }; | |||||
| #define MSPROF_MAX_DEV_NUM (64) | |||||
| struct MsprofCommandHandle { | |||||
| uint64_t profSwitch; | |||||
| uint32_t devNums; // length of device id list | |||||
| uint32_t devIdList[MSPROF_MAX_DEV_NUM]; | |||||
| uint32_t modelId; | |||||
| }; | }; | ||||
| /** | /** | ||||
| @@ -129,6 +141,23 @@ typedef int32_t (*MsprofCtrlCallback)(uint32_t type, void *data, uint32_t len); | |||||
| */ | */ | ||||
| typedef void (*MsprofSetDeviceCallback)(uint32_t devId, bool isOpenDevice); | typedef void (*MsprofSetDeviceCallback)(uint32_t devId, bool isOpenDevice); | ||||
| /* | |||||
| * @name MsprofInit | |||||
| * @brief Profiling module init | |||||
| * @param [in] dataType: profiling type: ACL Env/ACL Json/GE Option | |||||
| * @param [in] data: profiling switch data | |||||
| * @param [in] dataLen: Length of data | |||||
| * @return 0:SUCCESS, >0:FAILED | |||||
| */ | |||||
| int32_t MsprofInit(uint32_t dataType, void *data, uint32_t dataLen); | |||||
| /* | |||||
| * @name MsprofFinalize | |||||
| * @brief Finishing Profiling | |||||
| * @param NULL | |||||
| * @return 0:SUCCESS, >0:FAILED | |||||
| */ | |||||
| int32_t MsprofFinalize(); | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -17,6 +17,8 @@ | |||||
| #ifndef D_SYSLOG_H_ | #ifndef D_SYSLOG_H_ | ||||
| #define D_SYSLOG_H_ | #define D_SYSLOG_H_ | ||||
| static const int TMP_LOG = 0; | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| #ifndef LOG_CPP | #ifndef LOG_CPP | ||||
| extern "C" { | extern "C" { | ||||
| @@ -120,15 +122,15 @@ typedef struct tagKV { | |||||
| } KeyValue; | } KeyValue; | ||||
| typedef enum { | typedef enum { | ||||
| APPLICATION = 0, | |||||
| SYSTEM | |||||
| APPLICATION = 0, | |||||
| SYSTEM | |||||
| } ProcessType; | } ProcessType; | ||||
| typedef struct { | typedef struct { | ||||
| ProcessType type; | |||||
| unsigned int pid; | |||||
| unsigned int deviceId; | |||||
| char reserved[RESERVERD_LENGTH]; | |||||
| ProcessType type; | |||||
| unsigned int pid; | |||||
| unsigned int deviceId; | |||||
| char reserved[RESERVERD_LENGTH]; | |||||
| } LogAttr; | } LogAttr; | ||||
| /** | /** | ||||
| @@ -141,7 +143,7 @@ enum { | |||||
| IDEDD, /**< IDE daemon device */ | IDEDD, /**< IDE daemon device */ | ||||
| IDEDH, /**< IDE daemon host */ | IDEDH, /**< IDE daemon host */ | ||||
| HCCL, /**< HCCL */ | HCCL, /**< HCCL */ | ||||
| FMK, /**< Framework */ | |||||
| FMK, /**< Adapter */ | |||||
| HIAIENGINE, /**< Matrix */ | HIAIENGINE, /**< Matrix */ | ||||
| DVPP, /**< DVPP */ | DVPP, /**< DVPP */ | ||||
| RUNTIME, /**< Runtime */ | RUNTIME, /**< Runtime */ | ||||
| @@ -162,11 +164,11 @@ enum { | |||||
| MDCDEFAULT, /**< MDC undefine */ | MDCDEFAULT, /**< MDC undefine */ | ||||
| MDCSC, /**< MDC spatial cognition */ | MDCSC, /**< MDC spatial cognition */ | ||||
| MDCPNC, | MDCPNC, | ||||
| MLL, | |||||
| MLL, /**< abandon */ | |||||
| DEVMM, /**< Dlog memory management */ | DEVMM, /**< Dlog memory management */ | ||||
| KERNEL, /**< Kernel */ | KERNEL, /**< Kernel */ | ||||
| LIBMEDIA, /**< Libmedia */ | LIBMEDIA, /**< Libmedia */ | ||||
| CCECPU, /**< ai cpu */ | |||||
| CCECPU, /**< aicpu schedule */ | |||||
| ASCENDDK, /**< AscendDK */ | ASCENDDK, /**< AscendDK */ | ||||
| ROS, /**< ROS */ | ROS, /**< ROS */ | ||||
| HCCP, | HCCP, | ||||
| @@ -179,7 +181,7 @@ enum { | |||||
| TSDUMP, /**< TSDUMP module */ | TSDUMP, /**< TSDUMP module */ | ||||
| AICPU, /**< AICPU module */ | AICPU, /**< AICPU module */ | ||||
| LP, /**< LP module */ | LP, /**< LP module */ | ||||
| TDT, | |||||
| TDT, /**< tsdaemon or aicpu schedule */ | |||||
| FE, | FE, | ||||
| MD, | MD, | ||||
| MB, | MB, | ||||
| @@ -261,7 +263,7 @@ DLL_EXPORT int DlogSetAttr(LogAttr logAttr); | |||||
| #define dlog_error(moduleId, fmt, ...) \ | #define dlog_error(moduleId, fmt, ...) \ | ||||
| do { \ | do { \ | ||||
| DlogErrorInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ | DlogErrorInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ | ||||
| } while (0) | |||||
| } while (TMP_LOG != 0) | |||||
| /** | /** | ||||
| * @ingroup slog | * @ingroup slog | ||||
| @@ -276,7 +278,7 @@ DLL_EXPORT int DlogSetAttr(LogAttr logAttr); | |||||
| if(CheckLogLevel(moduleId, DLOG_WARN) == 1) { \ | if(CheckLogLevel(moduleId, DLOG_WARN) == 1) { \ | ||||
| DlogWarnInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ | DlogWarnInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ | ||||
| } \ | } \ | ||||
| } while (0) | |||||
| } while (TMP_LOG != 0) | |||||
| /** | /** | ||||
| * @ingroup slog | * @ingroup slog | ||||
| @@ -291,7 +293,7 @@ DLL_EXPORT int DlogSetAttr(LogAttr logAttr); | |||||
| if(CheckLogLevel(moduleId, DLOG_INFO) == 1) { \ | if(CheckLogLevel(moduleId, DLOG_INFO) == 1) { \ | ||||
| DlogInfoInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ | DlogInfoInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ | ||||
| } \ | } \ | ||||
| } while (0) | |||||
| } while (TMP_LOG != 0) | |||||
| /** | /** | ||||
| * @ingroup slog | * @ingroup slog | ||||
| @@ -306,7 +308,7 @@ DLL_EXPORT int DlogSetAttr(LogAttr logAttr); | |||||
| if(CheckLogLevel(moduleId, DLOG_DEBUG) == 1) { \ | if(CheckLogLevel(moduleId, DLOG_DEBUG) == 1) { \ | ||||
| DlogDebugInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ | DlogDebugInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ | ||||
| } \ | } \ | ||||
| } while (0) | |||||
| } while (TMP_LOG != 0) | |||||
| /** | /** | ||||
| * @ingroup slog | * @ingroup slog | ||||
| @@ -318,7 +320,7 @@ DLL_EXPORT int DlogSetAttr(LogAttr logAttr); | |||||
| #define dlog_event(moduleId, fmt, ...) \ | #define dlog_event(moduleId, fmt, ...) \ | ||||
| do { \ | do { \ | ||||
| DlogEventInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ | DlogEventInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ | ||||
| } while (0) | |||||
| } while (TMP_LOG != 0) | |||||
| /** | /** | ||||
| * @ingroup slog | * @ingroup slog | ||||
| @@ -334,7 +336,7 @@ DLL_EXPORT int DlogSetAttr(LogAttr logAttr); | |||||
| if(CheckLogLevel(moduleId, level) == 1) { \ | if(CheckLogLevel(moduleId, level) == 1) { \ | ||||
| DlogInner(moduleId, level, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ | DlogInner(moduleId, level, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ | ||||
| } \ | } \ | ||||
| } while (0) | |||||
| } while (TMP_LOG != 0) | |||||
| /** | /** | ||||
| * @ingroup slog | * @ingroup slog | ||||
| @@ -351,7 +353,7 @@ DLL_EXPORT int DlogSetAttr(LogAttr logAttr); | |||||
| if(CheckLogLevel(moduleId, level) == 1) { \ | if(CheckLogLevel(moduleId, level) == 1) { \ | ||||
| DlogInner(moduleId, level, "[%s:%d][%s]" fmt, __FILE__, __LINE__, submodule, ##__VA_ARGS__); \ | DlogInner(moduleId, level, "[%s:%d][%s]" fmt, __FILE__, __LINE__, submodule, ##__VA_ARGS__); \ | ||||
| } \ | } \ | ||||
| } while (0) | |||||
| } while (TMP_LOG != 0) | |||||
| /** | /** | ||||
| * @ingroup slog | * @ingroup slog | ||||
| @@ -369,7 +371,7 @@ DLL_EXPORT int DlogSetAttr(LogAttr logAttr); | |||||
| if(CheckLogLevel(moduleId, level) == 1) { \ | if(CheckLogLevel(moduleId, level) == 1) { \ | ||||
| DlogWithKVInner(moduleId, level, pstKVArray, kvNum, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ | DlogWithKVInner(moduleId, level, pstKVArray, kvNum, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ | ||||
| } \ | } \ | ||||
| } while (0) | |||||
| } while (TMP_LOG != 0) | |||||
| /** | /** | ||||
| * @ingroup slog | * @ingroup slog | ||||
| @@ -381,13 +383,13 @@ DLL_EXPORT void DlogFlush(void); | |||||
| * @ingroup slog | * @ingroup slog | ||||
| * @brief Internal log interface, other modules are not allowed to call this interface | * @brief Internal log interface, other modules are not allowed to call this interface | ||||
| */ | */ | ||||
| void DlogErrorInner(int moduleId, const char *fmt, ...) __attribute__((format(printf, 2, 3))); | |||||
| void DlogWarnInner(int moduleId, const char *fmt, ...) __attribute__((format(printf, 2, 3))); | |||||
| void DlogInfoInner(int moduleId, const char *fmt, ...) __attribute__((format(printf, 2, 3))); | |||||
| void DlogDebugInner(int moduleId, const char *fmt, ...) __attribute__((format(printf, 2, 3))); | |||||
| void DlogEventInner(int moduleId, const char *fmt, ...) __attribute__((format(printf, 2, 3))); | |||||
| void DlogInner(int moduleId, int level, const char *fmt, ...) __attribute__((format(printf, 3, 4))); | |||||
| void DlogWithKVInner(int moduleId, int level, KeyValue *pstKVArray, int kvNum, const char *fmt, ...) __attribute__((format(printf, 5, 6))); | |||||
| void DlogErrorInner(int moduleId, const char *fmt, ...); | |||||
| void DlogWarnInner(int moduleId, const char *fmt, ...); | |||||
| void DlogInfoInner(int moduleId, const char *fmt, ...); | |||||
| void DlogDebugInner(int moduleId, const char *fmt, ...); | |||||
| void DlogEventInner(int moduleId, const char *fmt, ...); | |||||
| void DlogInner(int moduleId, int level, const char *fmt, ...); | |||||
| void DlogWithKVInner(int moduleId, int level, KeyValue *pstKVArray, int kvNum, const char *fmt, ...); | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| #ifndef LOG_CPP | #ifndef LOG_CPP | ||||
| @@ -453,7 +455,7 @@ DLL_EXPORT int DlogSetAttrForC(LogAttr logAttr); | |||||
| if(CheckLogLevelForC(moduleId, level) == 1) { \ | if(CheckLogLevelForC(moduleId, level) == 1) { \ | ||||
| DlogInnerForC(moduleId, level, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ | DlogInnerForC(moduleId, level, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ | ||||
| } \ | } \ | ||||
| } while (0) | |||||
| } while (TMP_LOG != 0) | |||||
| /** | /** | ||||
| * @ingroup slog | * @ingroup slog | ||||
| @@ -470,7 +472,7 @@ DLL_EXPORT int DlogSetAttrForC(LogAttr logAttr); | |||||
| if(CheckLogLevelForC(moduleId, level) == 1) { \ | if(CheckLogLevelForC(moduleId, level) == 1) { \ | ||||
| DlogInnerForC(moduleId, level, "[%s:%d][%s]" fmt, __FILE__, __LINE__, submodule, ##__VA_ARGS__); \ | DlogInnerForC(moduleId, level, "[%s:%d][%s]" fmt, __FILE__, __LINE__, submodule, ##__VA_ARGS__); \ | ||||
| } \ | } \ | ||||
| } while (0) | |||||
| } while (TMP_LOG != 0) | |||||
| /** | /** | ||||
| * @ingroup slog | * @ingroup slog | ||||
| @@ -488,7 +490,7 @@ DLL_EXPORT int DlogSetAttrForC(LogAttr logAttr); | |||||
| if(CheckLogLevelForC(moduleId, level) == 1) { \ | if(CheckLogLevelForC(moduleId, level) == 1) { \ | ||||
| DlogWithKVInnerForC(moduleId, level, pstKVArray, kvNum, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ | DlogWithKVInnerForC(moduleId, level, pstKVArray, kvNum, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ | ||||
| } \ | } \ | ||||
| } while (0) | |||||
| } while (TMP_LOG != 0) | |||||
| /** | /** | ||||
| * @ingroup slog | * @ingroup slog | ||||
| @@ -500,8 +502,8 @@ DLL_EXPORT void DlogFlushForC(void); | |||||
| * @ingroup slog | * @ingroup slog | ||||
| * @brief Internal log interface, other modules are not allowed to call this interface | * @brief Internal log interface, other modules are not allowed to call this interface | ||||
| */ | */ | ||||
| void DlogInnerForC(int moduleId, int level, const char *fmt, ...) __attribute__((format(printf, 3, 4))); | |||||
| void DlogWithKVInnerForC(int moduleId, int level, KeyValue *pstKVArray, int kvNum, const char *fmt, ...) __attribute__((format(printf, 5, 6))); | |||||
| void DlogInnerForC(int moduleId, int level, const char *fmt, ...); | |||||
| void DlogWithKVInnerForC(int moduleId, int level, KeyValue *pstKVArray, int kvNum, const char *fmt, ...); | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| @@ -1,72 +1,88 @@ | |||||
| /** | |||||
| * @file tune_api.h | |||||
| * | |||||
| * Copyright (c) Huawei Technologies Co., Ltd. 2020-2020. All rights reserved.\n | |||||
| * | |||||
| * This program is distributed in the hope that it will be useful, | |||||
| * but WITHOUT ANY WARRANTY; without even the implied warranty of | |||||
| * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.\n | |||||
| * 描述:mstune调优接口头文件 | |||||
| */ | |||||
| /** @defgroup mstune mstune调优接口 */ | |||||
| #ifndef TUNE_API_H | |||||
| #define TUNE_API_H | |||||
| #include <vector> | |||||
| #include <map> | |||||
| #include <string> | |||||
| #include "graph/graph.h" | |||||
| #include "ge/ge_api.h" | |||||
| /** | |||||
| * @ingroup mstune | |||||
| * | |||||
| * mstune status | |||||
| */ | |||||
| enum MsTuneStatus { | |||||
| MSTUNE_SUCCESS, /** tune success */ | |||||
| MSTUNE_FAILED, /** tune failed */ | |||||
| }; | |||||
| // Option key: for train options sets | |||||
| const std::string MSTUNE_SELF_KEY = "mstune"; | |||||
| const std::string MSTUNE_GEINIT_KEY = "initialize"; | |||||
| const std::string MSTUNE_GESESS_KEY = "session"; | |||||
| /** | |||||
| * @ingroup mstune | |||||
| * @par 描述: 命令行调优 | |||||
| * | |||||
| * @attention 无 | |||||
| * @param option [IN] 调优参数 | |||||
| * @param msg [OUT] 调优异常下返回信息 | |||||
| * @retval #MSTUNE_SUCCESS 执行成功 | |||||
| * @retval #MSTUNE_FAILED 执行失败 | |||||
| * @par 依赖: | |||||
| * @li tune_api.cpp:该接口所属的开发包。 | |||||
| * @li tune_api.h:该接口声明所在的头文件。 | |||||
| * @see 无 | |||||
| * @since | |||||
| */ | |||||
| MsTuneStatus MsTuning(const std::map<std::string, std::string> &option, std::string &msg); | |||||
| /** | |||||
| * @ingroup mstune | |||||
| * @par Description: gradient tuning | |||||
| * | |||||
| * @attention None | |||||
| * @param tuningGraph [IN] graph to be tuned | |||||
| * @param dependGraph [IN] graphs that the tuning graph depends on | |||||
| * @param session [IN] GE connection session | |||||
| * @param option [IN] option set, containing both tuning options and GE options | |||||
| * @retval #MSTUNE_SUCCESS execution succeeded | |||||
| * @retval #MSTUNE_FAILED execution failed | |||||
| * @par Dependencies: | |||||
| * @li tune_api.cpp: the development package this interface belongs to. | |||||
| * @li tune_api.h: the header file in which this interface is declared. | |||||
| * @see None | |||||
| * @since | |||||
| */ | |||||
| extern "C" MsTuneStatus MsTrainTuning(ge::Graph &tuningGraph, std::vector<ge::Graph> &dependGraph, | |||||
| ge::Session *session, const std::map<std::string, std::map<std::string, std::string>> &option); | |||||
| #endif | |||||
| /** | |||||
| * @file tune_api.h | |||||
| * | |||||
| * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.\n | |||||
| * | |||||
| * This program is distributed in the hope that it will be useful, | |||||
| * but WITHOUT ANY WARRANTY; without even the implied warranty of | |||||
| * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.\n | |||||
| * Description: header file for the aoe tuning interfaces | |||||
| */ | |||||
| /** @defgroup aoe aoe tuning interfaces */ | |||||
| #ifndef TUNE_API_H | |||||
| #define TUNE_API_H | |||||
| #include <map> | |||||
| #include <string> | |||||
| #include "ge/ge_api.h" | |||||
| #include "aoe_types.h" | |||||
| /** | |||||
| * @ingroup aoe | |||||
| * @par Description: command-line tuning | |||||
| * | |||||
| * @attention None | |||||
| * @param option [IN] tuning parameters | |||||
| * @param msg [OUT] information returned when tuning encounters an error | |||||
| * @retval #AOE_SUCCESS execution succeeded | |||||
| * @retval #AOE_FAILURE execution failed | |||||
| * @par Dependencies: | |||||
| * @li tune_api.cpp: the development package this interface belongs to. | |||||
| * @li tune_api.h: the header file in which this interface is declared. | |||||
| * @see None | |||||
| * @since | |||||
| */ | |||||
| AoeStatus AoeOfflineTuning(const std::map<std::string, std::string> &option, std::string &msg); | |||||
| /** | |||||
| * @ingroup aoe | |||||
| * @par Description: tuning initialization | |||||
| * | |||||
| * @attention None | |||||
| * @param session [IN] GE connection session | |||||
| * @param option [IN] option set, containing both tuning options and GE options | |||||
| * @retval #AOE_SUCCESS execution succeeded | |||||
| * @retval #AOE_FAILURE execution failed | |||||
| * @par Dependencies: | |||||
| * @li tune_api.cpp: the development package this interface belongs to. | |||||
| * @li tune_api.h: the header file in which this interface is declared. | |||||
| * @see None | |||||
| * @since | |||||
| */ | |||||
| extern "C" AoeStatus AoeOnlineInitialize(ge::Session *session, const std::map<std::string, std::string> &option); | |||||
| /** | |||||
| * @ingroup aoe | |||||
| * @par Description: tuning de-initialization | |||||
| * | |||||
| * @attention None | |||||
| * @param None | |||||
| * @retval #AOE_SUCCESS execution succeeded | |||||
| * @retval #AOE_FAILURE execution failed | |||||
| * @par Dependencies: | |||||
| * @li tune_api.cpp: the development package this interface belongs to. | |||||
| * @li tune_api.h: the header file in which this interface is declared. | |||||
| * @see None | |||||
| * @since | |||||
| */ | |||||
| extern "C" AoeStatus AoeOnlineFinalize(); | |||||
| /** | |||||
| * @ingroup aoe | |||||
| * @par Description: tuning processing | |||||
| * | |||||
| * @attention None | |||||
| * @param tuningGraph [IN] graph to be tuned | |||||
| * @param dependGraph [IN] graphs that the tuning graph depends on | |||||
| * @param session [IN] GE connection session | |||||
| * @param option [IN] option set, containing both tuning options and GE options | |||||
| * @retval #AOE_SUCCESS execution succeeded | |||||
| * @retval #AOE_FAILURE execution failed | |||||
| * @par Dependencies: | |||||
| * @li tune_api.cpp: the development package this interface belongs to. | |||||
| * @li tune_api.h: the header file in which this interface is declared. | |||||
| * @see None | |||||
| * @since | |||||
| * @note NOTE(review): std::vector is used below, but this header's own include list has no <vector>; it presumably arrives transitively via ge/ge_api.h - confirm. | |||||
| */ | |||||
| extern "C" AoeStatus AoeOnlineTuning(ge::Graph &tuningGraph, std::vector<ge::Graph> &dependGraph, | |||||
| ge::Session *session, const std::map<std::string, std::string> &option); | |||||
| #endif | |||||