diff --git a/CMakeLists.txt b/CMakeLists.txt index ac0240d9..60509838 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -88,11 +88,9 @@ else () find_module(hccl libhccl.so ${GE_LIB_PATH}) find_module(adump_server libadump_server.a ${GE_LIB_PATH}) find_module(runtime libruntime.so ${GE_LIB_PATH}) - find_module(runtime_compile libruntime_compile.so ${GE_LIB_PATH}) find_module(resource libresource.so ${GE_LIB_PATH}) find_module(ascend_hal_stub libascend_hal.so ${GE_LIB_PATH}) find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${GE_LIB_PATH}) - #find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH}) else() find_module(slog libalog.so ${ASCEND_ATC_DIR}) find_module(opt_feature libopt_feature.so ${ASCEND_ATC_DIR}) @@ -108,7 +106,6 @@ else () elseif(PLATFORM STREQUAL "inference") find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR}) find_module(runtime libruntime.so ${ASCEND_ACL_DIR}) - find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR}) find_module(msprofiler_ext libmsprofiler.a ${ASCEND_ACL_DIR}) if(PRODUCT STREQUAL "flr3") elseif(PRODUCT STREQUAL "flr1") @@ -120,10 +117,9 @@ else () endif() elseif(PLATFORM STREQUAL "all") find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR}) - find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR}) + find_module(runtime libruntime.so ${ASCEND_ATC_DIR}) find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${ASCEND_RUNTIME_DIR}) - find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}) - find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR}) + find_module(ascend_hal_stub libascend_hal.so ${ASCEND_ATC_DIR}/stub) find_module(msprofiler_ext libmsprofiler.a ${ASCEND_ACL_DIR}) else() message(STATUS "PLATFORM param is invalid, should be train or inference, you choose nothing!") diff --git a/cmake/external_libs/gflags.cmake b/cmake/external_libs/gflags.cmake index 50cfb2bc..b4b57dd7 100755 --- a/cmake/external_libs/gflags.cmake +++ b/cmake/external_libs/gflags.cmake @@ -10,12 
+10,17 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") endif() -if (ENABLE_GITEE) - set(REQ_URL "https://gitee.com/mirrors/gflags/repository/archive/v2.2.2.tar.gz") - set(MD5 "") +if (GE_PB_PKG) + set(REQ_URL "${GE_PB_PKG}/libs/gflags/v2.2.2.tar.gz") + set(MD5 "1a865b93bacfa963201af3f75b7bd64c") else() - set(REQ_URL "https://github.com/gflags/gflags/archive/v2.2.2.tar.gz") - set(MD5 "") + if (ENABLE_GITEE) + set(REQ_URL "https://gitee.com/mirrors/gflags/repository/archive/v2.2.2.tar.gz") + set(MD5 "") + else() + set(REQ_URL "https://github.com/gflags/gflags/archive/v2.2.2.tar.gz") + set(MD5 "1a865b93bacfa963201af3f75b7bd64c") + endif () endif () set (gflags_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -Dgoogle=ascend_private") diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index cd65d5c1..a614f86d 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -112,6 +112,8 @@ set(EXECUTOR_SRC_LIST "common/dump/dump_op.cc" "common/dump/exception_dumper.cc" "common/dump/opdebug_register.cc" + "common/ge/op_tiling_manager.cc" + "common/ge/plugin_manager.cc" "common/profiling/ge_profiling.cc" "common/profiling/profiling_manager.cc" "executor/ge_executor.cc" @@ -259,6 +261,8 @@ set(EXECUTOR_SRC_LIST set(COMPILER_SRC_LIST "analyzer/analyzer.cc" "common/dump/dump_op.cc" + "common/ge/op_tiling_manager.cc" + "common/ge/plugin_manager.cc" "common/helper/model_cache_helper.cc" "common/profiling/profiling_manager.cc" "engine_manager/dnnengine_manager.cc" @@ -619,7 +623,6 @@ target_compile_definitions(ge_compiler PRIVATE REUSE_MEMORY=1 FMK_SUPPORT_DUMP FMK_HOST_INFER - COMPILE_OMG_PACKAGE google=ascend_private FUNC_VISIBILITY $<$:ONLY_COMPILE_OPEN_SRC> @@ -681,8 +684,7 @@ target_link_libraries(ge_compiler PRIVATE c_sec error_manager slog - $<$>:$> - $<$:$> + runtime opt_feature -Wl,--as-needed json diff --git a/ge/graph/load/model_manager/task_info/ffts_plus_task_info.cc 
b/ge/graph/load/model_manager/task_info/ffts_plus_task_info.cc index 4cc0f6a5..dce800d8 100644 --- a/ge/graph/load/model_manager/task_info/ffts_plus_task_info.cc +++ b/ge/graph/load/model_manager/task_info/ffts_plus_task_info.cc @@ -350,7 +350,7 @@ Status FftsPlusTaskInfo::InitAicAivCtx(const domi::FftsPlusAicAivCtxDef &ctx_def i_cache_prefetch_cnt_2)); ctx->tailTaskStartPcL = static_cast(reinterpret_cast(tail_task_start_pc) & 0XFFFFFFFF); ctx->tailTaskStartPcH = static_cast((reinterpret_cast(tail_task_start_pc) >> 32) & 0X0000FFFF); - uint32_t i_cache_prefetch_cnt = std::max(i_cache_prefetch_cnt_1, i_cache_prefetch_cnt_2); + uint32_t i_cache_prefetch_cnt = std::min(i_cache_prefetch_cnt_1, i_cache_prefetch_cnt_2); ctx->icachePrefetchCnt = static_cast(i_cache_prefetch_cnt & 0X0000001F); // 5 bits, 0001,1111 if (ctx_def.src_slot_size() != kSrcSlotNum) { @@ -526,8 +526,7 @@ Status FftsPlusTaskInfo::InitMixAicAivCtx(const domi::FftsPlusMixAicAivCtxDef &c ctx->tailAicTaskStartPcL = static_cast(reinterpret_cast(tail_aic_task_start_pc) & 0XFFFFFFFF); ctx->tailAicTaskStartPcH = static_cast((reinterpret_cast(tail_aic_task_start_pc) >> 32) & 0X0000FFFF); - uint32_t aic_i_cache_prefetch_cnt = std::max(i_cache_prefetch_cnt_1, i_cache_prefetch_cnt_2); - // TODO + uint32_t aic_i_cache_prefetch_cnt = std::min(i_cache_prefetch_cnt_1, i_cache_prefetch_cnt_2); ctx->icachePrefetchCnt = static_cast(aic_i_cache_prefetch_cnt & 0X0000001F); // 5 bits, 0001,1111 uint32_t i_cache_prefetch_cnt_3; @@ -545,9 +544,10 @@ Status FftsPlusTaskInfo::InitMixAicAivCtx(const domi::FftsPlusMixAicAivCtxDef &c ctx->tailAivTaskStartPcL = static_cast(reinterpret_cast(tail_aiv_task_start_pc) & 0XFFFFFFFF); ctx->tailAivTaskStartPcH = static_cast((reinterpret_cast(tail_aiv_task_start_pc) >> 32) & 0X0000FFFF); - uint32_t aiv_i_cache_prefetch_cnt = std::max(i_cache_prefetch_cnt_3, i_cache_prefetch_cnt_4); + uint32_t aiv_i_cache_prefetch_cnt = std::min(i_cache_prefetch_cnt_3, i_cache_prefetch_cnt_4); // TODO - 
ctx->icachePrefetchCnt = static_cast(aiv_i_cache_prefetch_cnt & 0X0000001F); // 5 bits, 0001,1111 + ctx->icachePrefetchCnt = static_cast( + std::min(aic_i_cache_prefetch_cnt, aiv_i_cache_prefetch_cnt) & 0X0000001F); // 5 bits, 0001,1111 if (ctx_def.src_slot_size() != kSrcSlotNum) { REPORT_INNER_ERROR("E19999", "Size of src_slot in FftsPlusMixAicAivCtxDef should be %d, but %d exactly", @@ -913,11 +913,11 @@ void FftsPlusTaskInfo::SetAdditionalDatatoCtx(const domi::FftsPlusTaskDef &task_ Status FftsPlusTaskInfo::UpdateMixAicAivCtxParam(const domi::FftsPlusMixAicAivCtxDef &ctx_def, size_t ctx_idx) { if (ctx_additional_data_.count(ctx_idx) == 0) { - GELOGD("ctx idx:%d not in ctx additional data"); + GELOGD("ctx idx:%zu not in ctx additional data"); return SUCCESS; } if (ctx_additional_data_[ctx_idx].count(kModeInArgsFirstField) == 0) { - GELOGD("ctx idx:%d need not to save mode in args first field"); + GELOGD("ctx idx:%zu need not to save mode in args first field"); return SUCCESS; } if (rtApp_addr_ == 0) { diff --git a/ge/graph/manager/graph_var_manager.cc b/ge/graph/manager/graph_var_manager.cc index 89a4e45b..d0669254 100755 --- a/ge/graph/manager/graph_var_manager.cc +++ b/ge/graph/manager/graph_var_manager.cc @@ -20,6 +20,7 @@ #include "graph/manager/graph_mem_manager.h" #include "graph/manager/trans_var_data_utils.h" #include "graph/utils/type_utils.h" +#include "graph/ge_context.h" using std::map; using std::string; @@ -767,25 +768,52 @@ Status VarManager::GetChangedGraphId(const std::string &var_name, uint32_t &grap return var_resource_->GetChangedGraphId(var_name, graph_id); } +Status VarManager::GetTotalMemorySize(size_t &total_mem_size) { + rtError_t rt_ret = rtSetDevice(GetContext().DeviceId()); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, ret:0x%X", + GetContext().DeviceId(), rt_ret); + GELOGE(RT_FAILED, "[Call][RtSetDevice] failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret); + return 
RT_FAILED; + } + size_t free_mem = 0; + rt_ret = rtMemGetInfoEx(RT_MEMORYINFO_HBM, &free_mem, &total_mem_size); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtMemGetInfo failed, ret:0x%X", rt_ret); + GELOGE(RT_FAILED, "[Call][RtMemGetInfo] failed, ret:0x%X", rt_ret); + return RT_FAILED; + } + rt_ret = rtDeviceReset(GetContext().DeviceId()); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, ret:0x%X", + GetContext().DeviceId(), rt_ret); + GELOGE(RT_FAILED, "[Call][RtDeviceReset] failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret); + return RT_FAILED; + } + return SUCCESS; +} + Status VarManager::SetMemoryMallocSize(const map &options) { - auto it = options.find(GRAPH_MEMORY_MAX_SIZE); - if (it == options.end()) { - graph_mem_max_size_ = kGraphMemoryManagerMallocMaxSize; - } else { - string graph_memory_manager_malloc_max_size = it->second; + size_t total_mem_size = 0; + GE_CHK_STATUS_RET_NOLOG(VarManager::GetTotalMemorySize(total_mem_size)); + GEEVENT("Total memory size is %zu", total_mem_size); + + graph_mem_max_size_ = floor(total_mem_size * kGraphMemoryManagerMallocRatio); + var_mem_max_size_ = floor(total_mem_size * kVarMemoryManagerMallocRatio); + + auto it1 = options.find(GRAPH_MEMORY_MAX_SIZE); + if (it1 != options.end()) { + string graph_memory_manager_malloc_max_size = it1->second; ge::Status ret = ParseMemoryMallocSize(graph_memory_manager_malloc_max_size, graph_mem_max_size_); if (ret != SUCCESS) { GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "[Call][ParseMemoryMallocSize] failed, session id:%lu.", session_id_); return ge::GE_GRAPH_OPTIONS_INVALID; } - GELOGI("The max size for graph mem is set to %zu", graph_mem_max_size_); } - it = options.find(VARIABLE_MEMORY_MAX_SIZE); - if (it == options.end()) { - var_mem_max_size_ = kMemoryVarManagerMallocSize; - } else { - string memory_var_manager_malloc_size = it->second; + auto it2 = options.find(VARIABLE_MEMORY_MAX_SIZE); + if 
(it2 != options.end()) { + string memory_var_manager_malloc_size = it2->second; ge::Status ret = ParseMemoryMallocSize(memory_var_manager_malloc_size, var_mem_max_size_); if (ret != SUCCESS) { GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "[Call][ParseMemoryMallocSize] failed, session id:%lu.", session_id_); @@ -793,6 +821,8 @@ Status VarManager::SetMemoryMallocSize(const map &options) { } } + GEEVENT("The graph_mem_max_size is %zu and the var_mem_max_size is %zu", graph_mem_max_size_, var_mem_max_size_); + var_mem_logic_base_ = graph_mem_max_size_ + kGraphMemoryBuffer; if (var_mem_logic_base_ > kMaxMemorySize) { REPORT_INNER_ERROR("E19999", "var_login_base:%zu can not exeed limit:%zu, session_id:%lu, check invalid", diff --git a/ge/graph/manager/graph_var_manager.h b/ge/graph/manager/graph_var_manager.h index f2b68e79..a1b45959 100755 --- a/ge/graph/manager/graph_var_manager.h +++ b/ge/graph/manager/graph_var_manager.h @@ -43,6 +43,8 @@ const size_t kMaxMemorySize = 256UL * 1024UL * 1024UL * 1024UL; const char kEnvGeuseStaticMemory[] = "GE_USE_STATIC_MEMORY"; const uint64_t kSessionMemAlignSize = 512; const size_t kSessionMemAlignUnit = 2; +const double kGraphMemoryManagerMallocRatio = 26.0 / 32.0; +const double kVarMemoryManagerMallocRatio = 5.0 / 32.0; enum MemStatus { NORMAL = 0, @@ -301,6 +303,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { mutable std::recursive_mutex mutex_; Status ParseMemoryMallocSize(std::string &memory_size, size_t &my_size); + Status GetTotalMemorySize(size_t &total_mem_size); }; class VarManagerPool { diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index 44115240..c89fbc42 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -60,7 +60,6 @@ const char *const kEngineNameRts = "DNN_VM_RTS_OP_STORE"; const char *const kForceInfershape = "_force_infershape_when_running"; const std::set kExecutionDependentTypes{ IF, STATELESSIF, 
CASE, STREAMSWITCH }; -const std::set kMergeInputSkipTypes{ STREAMACTIVE, STREAMSWITCH, CONSTANT, CONSTANTOP }; const std::set kStreamActiveTypes{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION }; Status SetOutputNameAttr(ComputeGraph &graph) { @@ -519,170 +518,6 @@ Status HybridModelBuilder::UpdateAnchorStatus(const NodePtr &node) { return SUCCESS; } -Status HybridModelBuilder::DoUnlinkDataAnchors(const OutDataAnchorPtr &out_data_anchor, - const InDataAnchorPtr &in_data_anchor) { - GE_CHK_GRAPH_STATUS_RET(out_data_anchor->Unlink(in_data_anchor), - "[Invoke][Unlink] failed to unlink %s:%d from %s:%d", - out_data_anchor->GetOwnerNode()->GetName().c_str(), out_data_anchor->GetIdx(), - in_data_anchor->GetOwnerNode()->GetName().c_str(), in_data_anchor->GetIdx()); - - GELOGD("Succeeded in unlinking %s:%d from %s:%d", - out_data_anchor->GetOwnerNode()->GetName().c_str(), - out_data_anchor->GetIdx(), - in_data_anchor->GetOwnerNode()->GetName().c_str(), - in_data_anchor->GetIdx()); - return SUCCESS; -} - -Status HybridModelBuilder::DoLinkDataAnchors(OutDataAnchorPtr &out_data_anchor, InDataAnchorPtr &in_data_anchor) { - GE_CHK_GRAPH_STATUS_RET(out_data_anchor->LinkTo(in_data_anchor), "[Invoke][LinkTo]Failed to link %s:%d to %s:%d", - out_data_anchor->GetOwnerNode()->GetName().c_str(), - out_data_anchor->GetIdx(), - in_data_anchor->GetOwnerNode()->GetName().c_str(), - in_data_anchor->GetIdx()); - - GELOGD("Succeeded in linking %s:%d to %s:%d", - out_data_anchor->GetOwnerNode()->GetName().c_str(), - out_data_anchor->GetIdx(), - in_data_anchor->GetOwnerNode()->GetName().c_str(), - in_data_anchor->GetIdx()); - return SUCCESS; -} - -Status HybridModelBuilder::MergeInputNodes(ComputeGraph &graph) { - const auto &wrapped_node = graph.GetParentNode(); - std::set root_nodes; - for (const auto &node : graph.GetDirectNode()) { - GE_CHECK_NOTNULL(node); - if (node->GetType() != DATA_TYPE) { - if (node->GetInDataNodes().empty()) { - root_nodes.emplace(node); - } - - continue; - } - - 
auto data_op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(data_op_desc); - - uint32_t parent_index = 0; - if (!AttrUtils::GetInt(data_op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { - GELOGE(FAILED, "[Invoke][GetInt] failed, node:[%s] attr:[%s]", - data_op_desc->GetName().c_str(), ATTR_NAME_PARENT_NODE_INDEX.c_str()); - REPORT_CALL_ERROR("E19999", "GetInt failed, node:[%s] attr:[%s]", - data_op_desc->GetName().c_str(), ATTR_NAME_PARENT_NODE_INDEX.c_str()); - return FAILED; - } - - auto wrapped_node_in_anchor = wrapped_node->GetInDataAnchor(parent_index); - GE_CHECK_NOTNULL(wrapped_node_in_anchor); - auto src_out_anchor = wrapped_node_in_anchor->GetPeerOutAnchor(); - if (src_out_anchor == nullptr || src_out_anchor->GetOwnerNode() == nullptr) { - continue; - } - wrapped_node_in_anchor->UnlinkAll(); - - // link src to outputs of DataNode - for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { - GE_CHECK_NOTNULL(out_data_anchor); - for (auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { - auto dst_node = peer_in_data_anchor->GetOwnerNode(); - GE_CHECK_NOTNULL(dst_node); - const auto in_nodes = dst_node->GetInDataNodes(); - if (std::all_of(in_nodes.begin(), in_nodes.end(), [](const NodePtr &n) { return n->GetType() == DATA; })) { - root_nodes.emplace(dst_node); - } - GE_CHK_STATUS_RET_NOLOG(DoUnlinkDataAnchors(out_data_anchor, peer_in_data_anchor)); - GE_CHK_STATUS_RET_NOLOG(DoLinkDataAnchors(src_out_anchor, peer_in_data_anchor)); - } - } - } - - // transfer in control edges to all root nodes - for (auto &root_node : root_nodes) { - auto in_nodes = root_node->GetInAllNodes(); - std::set in_node_set(in_nodes.begin(), in_nodes.end()); - for (auto &in_control_node : wrapped_node->GetInControlNodes()) { - if (in_node_set.count(in_control_node) == 0 && kMergeInputSkipTypes.count(root_node->GetType()) == 0) { - GELOGD("[%s] Restore control edge to [%s]", in_control_node->GetName().c_str(), root_node->GetName().c_str()); - 
GE_CHECK_NOTNULL(in_control_node->GetOutControlAnchor()); - (void) in_control_node->GetOutControlAnchor()->LinkTo(root_node->GetInControlAnchor()); - } - } - } - - wrapped_node->GetInControlAnchor()->UnlinkAll(); - return SUCCESS; -} - -Status HybridModelBuilder::MergeNetOutputNode(ComputeGraph &graph) { - const auto &parent_node = graph.GetParentNode(); - const NodePtr &net_output_node = graph.FindFirstNodeMatchType(NETOUTPUT); - if (net_output_node == nullptr) { - GELOGD("Graph has no netoutput no need to merge"); - return SUCCESS; - } - const auto &net_output_desc = net_output_node->GetOpDesc(); - GE_CHECK_NOTNULL(net_output_desc); - - auto all_in_nodes = net_output_node->GetInAllNodes(); - auto all_out_nodes = parent_node->GetOutAllNodes(); - net_output_node->GetInControlAnchor()->UnlinkAll(); - parent_node->GetOutControlAnchor()->UnlinkAll(); - - for (const auto &in_data_anchor : net_output_node->GetAllInDataAnchors()) { - auto src_out_anchor = in_data_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(src_out_anchor); - GE_CHECK_NOTNULL(src_out_anchor->GetOwnerNode()); - GE_CHK_STATUS_RET_NOLOG(DoUnlinkDataAnchors(src_out_anchor, in_data_anchor)); - - auto index = in_data_anchor->GetIdx(); - auto input_desc = net_output_desc->MutableInputDesc(index); - if (input_desc == nullptr) { - GELOGE(INTERNAL_ERROR, "[Invoke][MutableInputDesc][%s] Failed to get input desc[%d]", - net_output_desc->GetName().c_str(), index); - REPORT_CALL_ERROR("E19999", "[%s] Failed to get input desc[%d].", net_output_desc->GetName().c_str(), index); - return INTERNAL_ERROR; - } - - uint32_t parent_index = 0; - if (!AttrUtils::GetInt(input_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { - GELOGW("SubGraph: %s NetOutput input tensor %d, attr %s not found.", - graph.GetName().c_str(), index, ATTR_NAME_PARENT_NODE_INDEX.c_str()); - continue; - } - - const OutDataAnchorPtr &parent_out_anchor = parent_node->GetOutDataAnchor(parent_index); - GE_CHECK_NOTNULL(parent_out_anchor); - for 
(InDataAnchorPtr &dst_in_anchor : parent_out_anchor->GetPeerInDataAnchors()) { - if (dst_in_anchor == nullptr) { - continue; - } - - GE_CHECK_NOTNULL(dst_in_anchor->GetOwnerNode()); - GE_CHK_STATUS_RET_NOLOG(DoUnlinkDataAnchors(parent_out_anchor, dst_in_anchor)); - GE_CHK_STATUS_RET_NOLOG(DoLinkDataAnchors(src_out_anchor, dst_in_anchor)); - } - } - - // transfer out control edges - std::set in_node_set(all_in_nodes.begin(), all_in_nodes.end()); - std::set out_node_set(all_out_nodes.begin(), all_out_nodes.end()); - for (auto &src_node : in_node_set) { - GELOGD("[%s] process in node.", src_node->GetName().c_str()); - auto out_nodes = src_node->GetOutAllNodes(); - std::set node_set(out_nodes.begin(), out_nodes.end()); - for (auto &dst_node : out_node_set) { - if (node_set.count(dst_node) == 0) { - src_node->GetOutControlAnchor()->LinkTo(dst_node->GetInControlAnchor()); - GELOGD("[%s] Restore control edge to [%s]", src_node->GetName().c_str(), dst_node->GetName().c_str()); - } - } - } - - return SUCCESS; -} - Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph) { merged_graph = MakeShared("MergedGraph"); merged_graph->SetGraphUnknownFlag(root_graph->GetGraphUnknownFlag()); @@ -716,9 +551,21 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeG } } } - GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, merged_graph, *subgraph), + + const auto &filter = [](const ComputeGraphPtr &graph) { + const auto &parent_node = graph->GetParentNode(); + if (parent_node == nullptr || parent_node->GetOpDesc() == nullptr) { + return false; + } + if ((parent_node->GetType() != PARTITIONEDCALL) || + (parent_node->GetOpDesc()->GetSubgraphInstanceNames().size() != 1)) { + return false; + } + return graph->GetGraphUnknownFlag(); + }; + GE_CHK_GRAPH_STATUS_RET(GraphUtils::UnfoldSubgraph(subgraph, filter), "[Invoke][UnfoldSubgraph][%s] Failed to merge subgraph.", - subgraph->GetName().c_str()); + 
subgraph->GetName().c_str()) } // invoke before adding subgraphs. in case modify node id in known-shaped subgraphs. @@ -744,56 +591,6 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeG return SUCCESS; } -Status HybridModelBuilder::UnfoldSubgraph(ComputeGraphPtr &root_graph, - ComputeGraphPtr &parent_graph, - ComputeGraph &sub_graph) { - auto parent_node = sub_graph.GetParentNode(); - GE_CHECK_NOTNULL(parent_node); - - GE_CHK_STATUS_RET(MergeInputNodes(sub_graph), - "[Invoke][MergeInputNodes][%s] Failed to merge data nodes for subgraph", - sub_graph.GetName().c_str()); - GE_CHK_STATUS_RET(MergeNetOutputNode(sub_graph), - "[Invoke][MergeNetOutputNode][%s] Failed to merge net output nodes for subgraph", - sub_graph.GetName().c_str()); - GELOGD("[%s] Done merging subgraph inputs and outputs successfully", sub_graph.GetName().c_str()); - - for (auto &sub_node : sub_graph.GetDirectNode()) { - auto sub_op_type = sub_node->GetType(); - if (sub_op_type == DATA_TYPE || sub_op_type == NETOUTPUT) { - continue; - } - if (sub_op_type == PARTITIONEDCALL) { - auto sub_sub_graph = NodeUtils::GetSubgraph(*sub_node, kSubgraphIndex); - GE_CHECK_NOTNULL(sub_sub_graph); - if (sub_sub_graph->GetGraphUnknownFlag()) { - GE_CHK_STATUS_RET(UnfoldSubgraph(root_graph, parent_graph, *sub_sub_graph), - "[Invoke][UnfoldSubgraph][%s] Failed to merge subgraph", - sub_sub_graph->GetName().c_str()); - continue; - } - } - - if (!sub_node->GetOpDesc()->GetSubgraphInstanceNames().empty()) { - for (size_t i = 0; i < sub_node->GetOpDesc()->GetSubgraphInstanceNames().size(); ++i) { - auto sub_sub_graph = NodeUtils::GetSubgraph(*sub_node, i); - GE_CHECK_NOTNULL(sub_sub_graph); - sub_sub_graph->SetParentGraph(parent_graph); - } - } - parent_graph->AddNode(sub_node); - GELOGD("[%s::%s] added to parent graph: [%s].", - sub_graph.GetName().c_str(), - sub_node->GetName().c_str(), - parent_graph->GetName().c_str()); - sub_node->SetOwnerComputeGraph(parent_graph); - } - - 
GELOGD("[%s] Done merging subgraph. remove it from root graph", sub_graph.GetName().c_str()); - root_graph->RemoveSubgraph(sub_graph.GetName()); - return SUCCESS; -} - Status HybridModelBuilder::BuildOutputMapping(GraphItem &graph_item, const NodeItem &node_item, bool is_root_graph) { diff --git a/ge/hybrid/model/hybrid_model_builder.h b/ge/hybrid/model/hybrid_model_builder.h index 3592d3d2..52d519ef 100644 --- a/ge/hybrid/model/hybrid_model_builder.h +++ b/ge/hybrid/model/hybrid_model_builder.h @@ -39,16 +39,11 @@ class HybridModelBuilder { private: static Status UpdateAnchorStatus(const NodePtr &node); - static Status DoUnlinkDataAnchors(const OutDataAnchorPtr &out_data_anchor, const InDataAnchorPtr &in_data_anchor); - static Status DoLinkDataAnchors(OutDataAnchorPtr &out_data_anchor, InDataAnchorPtr &in_data_anchor); static NodePtr GetPeerNode(const InDataAnchorPtr &in_data_anchor); static Status GetParentNodeOutputIndex(const OpDesc &op_desc, int index, uint32_t &out_index); static Status GetPeerNodeAcrossSubGraphs(const NodePtr &data_node, NodePtr &peer_node, int &peer_out_index); static Status HandleDtString(const GeTensor &tensor, void *var_addr); - static Status MergeInputNodes(ComputeGraph &compute_graph); - static Status MergeNetOutputNode(ComputeGraph &compute_graph); static Status UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph); - static Status UnfoldSubgraph(ComputeGraphPtr &root_graph, ComputeGraphPtr &parent_graph, ComputeGraph &sub_graph); static Status BuildInputMapping(GraphItem &graph_item, std::vector &data_nodes, bool is_root_graph); diff --git a/ge/offline/CMakeLists.txt b/ge/offline/CMakeLists.txt index e11e4a03..935d8a30 100644 --- a/ge/offline/CMakeLists.txt +++ b/ge/offline/CMakeLists.txt @@ -22,7 +22,6 @@ target_compile_options(atc_atc.bin PRIVATE target_compile_definitions(atc_atc.bin PRIVATE PROTOBUF_INLINE_NOT_IN_HEADERS=0 - COMPILE_OMG_PACKAGE google=ascend_private LOG_CPP FUNC_VISIBILITY @@ -48,6 +47,7 @@ 
target_include_directories(atc_atc.bin PRIVATE target_link_options(atc_atc.bin PRIVATE -Wl,-Bsymbolic + -Wl,-rpath-link,${ASCEND_ATC_DIR}/stub ) target_link_libraries(atc_atc.bin PRIVATE @@ -62,8 +62,7 @@ target_link_libraries(atc_atc.bin PRIVATE parser_common gflags json - $<$>:$> - $<$:$> + runtime slog static_mmpa -lrt @@ -92,7 +91,6 @@ target_compile_options(fwk_atc.bin PRIVATE target_compile_definitions(fwk_atc.bin PRIVATE PROTOBUF_INLINE_NOT_IN_HEADERS=0 - COMPILE_OMG_PACKAGE google=ascend_private LOG_CPP FUNC_VISIBILITY diff --git a/tests/depends/runtime/src/runtime_stub.cc b/tests/depends/runtime/src/runtime_stub.cc index 53c761bd..effe0b68 100644 --- a/tests/depends/runtime/src/runtime_stub.cc +++ b/tests/depends/runtime/src/runtime_stub.cc @@ -193,6 +193,12 @@ rtError_t rtMemGetInfo(size_t *free, size_t *total) { return RT_ERROR_NONE; } +rtError_t rtMemGetInfoEx(rtMemInfoType_t memInfoType, size_t *free, size_t *total) { + *free = 512UL * 1024UL * 1024UL; + *total = 1024UL * 1024UL * 1024UL; + return RT_ERROR_NONE; +} + rtError_t rtMemAllocManaged(void **ptr, uint64_t size, uint32_t flag) { return RT_ERROR_NONE; } rtError_t rtMemFreeManaged(void *ptr) { return RT_ERROR_NONE; } diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 085b510c..9f095c6b 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -692,6 +692,7 @@ set(MULTI_PARTS_TEST_FILES "graph/manager/run_graph_unittest.cc" "graph/partition/dynamic_shape_partition_unittest.cc" "graph/manager/graph_manager_unittest.cc" + "graph/manager/graph_var_manager_unittest.cc" "graph/optimize/mem_rw_conflict_optimize_unittest.cc" "graph/optimize/graph_optimize_unittest.cc" "session/omg_omg_unittest.cc" diff --git a/tests/ut/ge/graph/load/ffts_plus_task_info_unittest.cc b/tests/ut/ge/graph/load/ffts_plus_task_info_unittest.cc index a68e5e65..57a25b54 100644 --- a/tests/ut/ge/graph/load/ffts_plus_task_info_unittest.cc +++ 
b/tests/ut/ge/graph/load/ffts_plus_task_info_unittest.cc @@ -79,13 +79,12 @@ public: additionaldata1->add_context_id(5); } - void InitAicAivCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { - domi::FftsPlusAicAivCtxDef *ctxdef = fftsplusctxdef->mutable_aic_aiv_ctx(); + void InitAicAivCtx(domi::FftsPlusAicAivCtxDef *ctxdef) { ctxdef->set_successor_num(26); ctxdef->set_aten(1); ctxdef->set_pred_cnt_init(1); ctxdef->set_pred_cnt(1); - for (int i = 0; i < RT_CTX_SUCCESSOR_NUM; ++i) { + for (int i = 1; i < RT_CTX_SUCCESSOR_NUM; ++i) { ctxdef->add_successor_list(1); // 16 bits, len = 26 } ctxdef->set_stat(1); @@ -113,8 +112,7 @@ public: } } - void InitMixAicAivCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { - domi::FftsPlusMixAicAivCtxDef *ctxdef = fftsplusctxdef->mutable_mix_aic_aiv_ctx(); + void InitMixAicAivCtx(domi::FftsPlusMixAicAivCtxDef *ctxdef) { ctxdef->set_successor_num(26); ctxdef->set_aten(1); ctxdef->set_pred_cnt_init(1); @@ -153,8 +151,7 @@ public: } } - void InitSdmaCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { - domi::FftsPlusSdmaCtxDef *ctxdef = fftsplusctxdef->mutable_sdma_ctx(); + void InitSdmaCtx(domi::FftsPlusSdmaCtxDef *ctxdef) { ctxdef->set_successor_num(26); ctxdef->set_aten(1); ctxdef->set_pred_cnt_init(1); @@ -184,8 +181,7 @@ public: ctxdef->set_tail_data_len(1); } - void InitNotifyCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { - domi::FftsPlusNotifyCtxDef *ctxdef = fftsplusctxdef->mutable_notify_ctx(); + void InitNotifyCtx(domi::FftsPlusNotifyCtxDef *ctxdef) { ctxdef->set_successor_num(26); ctxdef->set_aten(1); ctxdef->set_pred_cnt_init(1); @@ -201,8 +197,7 @@ public: ctxdef->set_notify_id_base(1); } - void InitWriteValueCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { - domi::FftsPlusWriteValueCtxDef *ctxdef = fftsplusctxdef->mutable_write_value_ctx(); + void InitWriteValueCtx(domi::FftsPlusWriteValueCtxDef *ctxdef) { ctxdef->set_successor_num(26); ctxdef->set_aten(1); ctxdef->set_pred_cnt_init(1); @@ -227,8 +222,7 @@ public: ctxdef->add_write_value(1); } - void 
InitAicpuCtxCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { - domi::FftsPlusAicpuCtxDef *ctxdef = fftsplusctxdef->mutable_aicpu_ctx(); + void InitAicpuCtxCtx(domi::FftsPlusAicpuCtxDef *ctxdef) { ctxdef->set_successor_num(26); ctxdef->set_aten(1); ctxdef->set_pred_cnt_init(1); @@ -260,8 +254,7 @@ public: ctxdef->set_task_param_offset(32); } - void InitDataCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { - domi::FftsPlusDataCtxDef *ctxdef = fftsplusctxdef->mutable_data_ctx(); + void InitDataCtx(domi::FftsPlusDataCtxDef *ctxdef) { ctxdef->set_successor_num(26); ctxdef->set_aten(1); ctxdef->set_cnt_init(1); @@ -293,8 +286,7 @@ public: ctxdef->set_tail_stride_inner(1); } - void InitAtStartCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { - domi::FftsPlusAtStartCtxDef *ctxdef = fftsplusctxdef->mutable_at_start_ctx(); + void InitAtStartCtx(domi::FftsPlusAtStartCtxDef *ctxdef) { ctxdef->set_successor_num(26); ctxdef->set_aten(1); ctxdef->set_pred_cnt_init(1); @@ -309,8 +301,7 @@ public: ctxdef->set_thread_window_size(1); } - void InitAtEndCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { - domi::FftsPlusAtEndCtxDef *ctxdef = fftsplusctxdef->mutable_at_end_ctx(); + void InitAtEndCtx(domi::FftsPlusAtEndCtxDef *ctxdef) { ctxdef->set_at_start_slot_num(12); ctxdef->set_out_label_slot_num(12); ctxdef->set_aten(1); @@ -325,8 +316,7 @@ public: ctxdef->set_thread_id(1); } - void InitLabelCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { - domi::FftsPlusLabelCtxDef *ctxdef = fftsplusctxdef->mutable_label_ctx(); + void InitLabelCtx(domi::FftsPlusLabelCtxDef *ctxdef) { ctxdef->set_successor_num(26); ctxdef->set_pred_cnt_init(1); ctxdef->set_pred_cnt(1); @@ -335,8 +325,7 @@ public: } } - void InitCaseSwitchCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { - domi::FftsPlusCaseSwitchCtxDef *ctxdef = fftsplusctxdef->mutable_case_switch_ctx(); + void InitCaseSwitchCtx(domi::FftsPlusCaseSwitchCtxDef *ctxdef) { ctxdef->set_successor_num(26); ctxdef->set_aten(32); ctxdef->set_start_label_id(32); @@ -366,8 +355,7 @@ public: 
ctxdef->set_load_addr1_offset(32); } - void InitCaseDefaultCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { - domi::FftsPlusCaseDefaultCtxDef *ctxdef = fftsplusctxdef->mutable_case_default_ctx(); + void InitCaseDefaultCtx(domi::FftsPlusCaseDefaultCtxDef *ctxdef) { ctxdef->set_successor_num(26); ctxdef->set_aten(32); ctxdef->set_start_label_id(1); @@ -379,8 +367,7 @@ public: } } - void InitCondSwitchCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { - domi::FftsPlusCondSwitchCtxDef *ctxdef = fftsplusctxdef->mutable_cond_switch_ctx(); + void InitCondSwitchCtx(domi::FftsPlusCondSwitchCtxDef *ctxdef) { ctxdef->set_true_successor_num(12); ctxdef->set_false_successor_num(14); ctxdef->set_aten(32); @@ -444,35 +431,38 @@ TEST_F(UtestFftsPlusTaskInfo, success_ffts_plus_task_info_software_ctx) { InitTaskSQEInfo(ffts_plus_task_def); InitTaskAdditionalDataInfo(ffts_plus_task_def); - domi::FftsPlusCtxDef *startctx = ffts_plus_task_def->add_ffts_plus_ctx(); - startctx->set_op_index(0); - startctx->set_hardware_ctx_type(0); - startctx->set_software_ctx_type(static_cast(RT_SOFT_CTX_TYPE_AT_START)); - InitAtStartCtx(startctx); + domi::FftsPlusCtxDef *fftsplusstartctx = ffts_plus_task_def->add_ffts_plus_ctx(); + fftsplusstartctx->set_op_index(0); + fftsplusstartctx->set_hardware_ctx_type(0); + fftsplusstartctx->set_software_ctx_type(static_cast(RT_SOFT_CTX_TYPE_AT_START)); + domi::FftsPlusAtStartCtxDef *startctxdef = fftsplusstartctx->mutable_at_start_ctx(); + InitAtStartCtx(startctxdef); EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), FAILED); - startctx->at_start_ctx().add_successor_list(1); + startctxdef->add_successor_list(1); EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), SUCCESS); - domi::FftsPlusCtxDef *endctx = ffts_plus_task_def->add_ffts_plus_ctx(); - endctx->set_op_index(0); - endctx->set_hardware_ctx_type(0); - endctx->set_software_ctx_type(static_cast(RT_SOFT_CTX_TYPE_AT_END)); - InitAtEndCtx(endctx); + domi::FftsPlusCtxDef *fftsplusendctx = 
ffts_plus_task_def->add_ffts_plus_ctx(); + fftsplusendctx->set_op_index(0); + fftsplusendctx->set_hardware_ctx_type(0); + fftsplusendctx->set_software_ctx_type(static_cast(RT_SOFT_CTX_TYPE_AT_END)); + domi::FftsPlusAtEndCtxDef *endctxdef = fftsplusendctx->mutable_at_end_ctx(); + InitAtEndCtx(endctxdef); EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), FAILED); - endctx->at_end_ctx().add_succ_at_start_slot(1); + endctxdef->add_succ_at_start_slot(1); EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), FAILED); - endctx->at_end_ctx().add_succ_out_label_slot(1); + endctxdef->add_succ_out_label_slot(1); EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), SUCCESS); - domi::FftsPlusCtxDef *labelctx = ffts_plus_task_def->add_ffts_plus_ctx(); - labelctx->set_op_index(0); - labelctx->set_hardware_ctx_type(0); - labelctx->set_software_ctx_type(static_cast(RT_SOFT_CTX_TYPE_LABEL)); - InitLabelCtx(labelctx); + domi::FftsPlusCtxDef *fftspluslabelctx = ffts_plus_task_def->add_ffts_plus_ctx(); + fftspluslabelctx->set_op_index(0); + fftspluslabelctx->set_hardware_ctx_type(0); + fftspluslabelctx->set_software_ctx_type(static_cast(RT_SOFT_CTX_TYPE_LABEL)); + domi::FftsPlusLabelCtxDef *labelctxdef = fftsplusctxdef->mutable_label_ctx(); + InitLabelCtx(labelctxdef); EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), FAILED); - labelctx->label_ctx().add_successor_list(1); + labelctxdef->add_successor_list(1); EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), SUCCESS); } @@ -501,102 +491,111 @@ TEST_F(UtestFftsPlusTaskInfo, success_ffts_plus_task_info_hardware_ctx) { aicaivctx->set_op_index(0); aicaivctx->set_hardware_ctx_type(static_cast(RT_HW_CTX_TYPE_AIV)); aicaivctx->set_software_ctx_type(0); - InitAicAivCtx(aicaivctx); + domi::FftsPlusAicAivCtxDef *aicaivdef = aicaivctx->mutable_aic_aiv_ctx(); + InitAicAivCtx(aicaivdef); EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); - 
aicaivctx->aic_aiv_ctx().add_successor_list(1); + aicaivdef->add_successor_list(1); EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); - aicaivctx->aic_aiv_ctx().add_kernel_name("aivtest"); + aicaivdef->add_kernel_name("aivtest"); EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); - aicaivctx->aic_aiv_ctx().add_src_slot(1); + aicaivdef->add_src_slot(1); EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); domi::FftsPlusCtxDef *mixaicaivctx = ffts_plus_task_def->add_ffts_plus_ctx(); mixaicaivctx->set_op_index(0); mixaicaivctx->set_hardware_ctx_type(static_cast(RT_HW_CTX_TYPE_MIX_AIC)); mixaicaivctx->set_software_ctx_type(0); - InitMixAicAivCtx(mixaicaivctx); + /* Fetch the mutable mix AIC/AIV sub-message once; it is initialised here and extended step by step below. */ domi::FftsPlusMixAicAivCtxDef *mixctxdef = mixaicaivctx->mutable_mix_aic_aiv_ctx(); + InitMixAicAivCtx(mixctxdef); EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); - mixaicaivctx->mix_aic_aiv_ctx().add_successor_list(1); + mixctxdef->add_successor_list(1); EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); - mixaicaivctx->mix_aic_aiv_ctx().add_kernel_name("mixaiv"); + mixctxdef->add_kernel_name("mixaiv"); EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); - mixaicaivctx->mix_aic_aiv_ctx().add_src_slot(1); + mixctxdef->add_src_slot(1); EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); domi::FftsPlusCtxDef *notifyctx = ffts_plus_task_def->add_ffts_plus_ctx(); notifyctx->set_op_index(0); notifyctx->set_hardware_ctx_type(static_cast(RT_HW_CTX_TYPE_NOTIFY_WAIT)); notifyctx->set_software_ctx_type(0); - InitNotifyCtx(notifyctx); + /* Mutable notify sub-message: initialised here, successor appended below. */ domi::FftsPlusNotifyCtxDef *notifydef = notifyctx->mutable_notify_ctx(); + InitNotifyCtx(notifydef); EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); - notifyctx->notify_ctx().add_successor_list(1); + notifydef->add_successor_list(1); EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); domi::FftsPlusCtxDef *sdmactx = ffts_plus_task_def->add_ffts_plus_ctx(); sdmactx->set_op_index(0); 
sdmactx->set_hardware_ctx_type(static_cast(RT_HW_CTX_TYPE_SDMA)); sdmactx->set_software_ctx_type(0); - InitSdmaCtx(sdmactx); + /* Mutable SDMA sub-message: initialised here, successor appended below. */ domi::FftsPlusSdmaCtxDef *sdmadef = sdmactx->mutable_sdma_ctx(); + InitSdmaCtx(sdmadef); EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); - sdmactx->sdma_ctx().add_successor_list(1); + sdmadef->add_successor_list(1); EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); domi::FftsPlusCtxDef *writevalctx = ffts_plus_task_def->add_ffts_plus_ctx(); writevalctx->set_op_index(0); writevalctx->set_hardware_ctx_type(static_cast(RT_HW_CTX_TYPE_WRITE_VALUE)); writevalctx->set_software_ctx_type(0); - InitWriteValueCtx(writevalctx); + /* Mutable write-value sub-message: initialised here, successor appended below. */ domi::FftsPlusWriteValueCtxDef *writedef = writevalctx->mutable_write_value_ctx(); + InitWriteValueCtx(writedef); EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); - writevalctx->write_value_ctx().add_successor_list(1); + writedef->add_successor_list(1); EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); domi::FftsPlusCtxDef *aicpuctx = ffts_plus_task_def->add_ffts_plus_ctx(); aicpuctx->set_op_index(0); aicpuctx->set_hardware_ctx_type(static_cast(RT_HW_CTX_TYPE_AICPU)); aicpuctx->set_software_ctx_type(0); - InitAicpuCtxCtx(aicpuctx); + /* Mutable AICPU sub-message: initialised here, then a successor id and user data are appended below. */ domi::FftsPlusAicpuCtxDef *aicpudef = aicpuctx->mutable_aicpu_ctx(); + InitAicpuCtxCtx(aicpudef); EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); - aicpuctx->aicpu_ctx().add_successor_context_id(1); + aicpudef->add_successor_context_id(1); EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); - aicpuctx->aicpu_ctx().add_user_data(1); + aicpudef->add_user_data(1); EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); domi::FftsPlusCtxDef *datactx = ffts_plus_task_def->add_ffts_plus_ctx(); datactx->set_op_index(0); datactx->set_hardware_ctx_type(static_cast(RT_HW_CTX_TYPE_FLUSH_DATA)); datactx->set_software_ctx_type(0); - InitDataCtx(datactx); + domi::FftsPlusDataCtxDef *datadef = datactx->mutable_data_ctx(); + 
InitDataCtx(datadef); EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); - datactx->data_ctx().add_successor_list(1); + datadef->add_successor_list(1); EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); domi::FftsPlusCtxDef *caseswitchctx = ffts_plus_task_def->add_ffts_plus_ctx(); caseswitchctx->set_op_index(0); caseswitchctx->set_hardware_ctx_type(static_cast(RT_HW_CTX_TYPE_LOAD)); caseswitchctx->set_software_ctx_type(static_cast(RT_SOFT_CTX_TYPE_CASE_SWITCH)); - InitCaseSwitchCtx(caseswitchctx); + /* Mutable case-switch sub-message: initialised here, successor appended below. */ domi::FftsPlusCaseSwitchCtxDef *caseswitchdef = caseswitchctx->mutable_case_switch_ctx(); + InitCaseSwitchCtx(caseswitchdef); EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); - caseswitchctx->case_switch_ctx().add_successor_list(1); + caseswitchdef->add_successor_list(1); EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); domi::FftsPlusCtxDef *candswitchctx = ffts_plus_task_def->add_ffts_plus_ctx(); candswitchctx->set_op_index(0); candswitchctx->set_hardware_ctx_type(static_cast(RT_HW_CTX_TYPE_LOAD)); candswitchctx->set_software_ctx_type(static_cast(RT_SOFT_CTX_TYPE_COND_SWITCH)); - InitCondSwitchCtx(candswitchctx); + /* Mutable cond-switch sub-message: initialised here, then the true and false successor lists are extended below. */ domi::FftsPlusCondSwitchCtxDef *candswitchdef = candswitchctx->mutable_cond_switch_ctx(); + InitCondSwitchCtx(candswitchdef); EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); - candswitchctx->cond_switch_ctx().add_true_successor_list(1); + candswitchdef->add_true_successor_list(1); EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); - candswitchctx->cond_switch_ctx().add_false_successor_list(1); + candswitchdef->add_false_successor_list(1); EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); } @@ -625,10 +624,11 @@ TEST_F(UtestFftsPlusTaskInfo, success_ffts_plus_task_info_hardware_ctx_ex) { casesdefaultctx->set_op_index(0); casesdefaultctx->set_hardware_ctx_type(static_cast(RT_HW_CTX_TYPE_LOAD)); casesdefaultctx->set_software_ctx_type(static_cast(RT_SOFT_CTX_TYPE_CASE_SWITCH)); - 
InitCaseDefaultCtx(casesdefaultctx); + /* Mutable case-default sub-message: initialised here, successor appended below. */ domi::FftsPlusCaseDefaultCtxDef *casesdefaultdef = casesdefaultctx->mutable_case_default_ctx(); + InitCaseDefaultCtx(casesdefaultdef); EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); - casesdefaultctx->case_default_ctx().add_successor_list(1); + casesdefaultdef->add_successor_list(1); EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); } // test FftsPlusTaskInfo UpdateArgs diff --git a/tests/ut/ge/graph/manager/graph_var_manager_unittest.cc b/tests/ut/ge/graph/manager/graph_var_manager_unittest.cc new file mode 100644 index 00000000..c20e786d --- /dev/null +++ b/tests/ut/ge/graph/manager/graph_var_manager_unittest.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include +/* The access specifiers are redefined around the project includes, presumably so the tests can read members such as graph_mem_max_size_ and var_mem_max_size_ directly. */ +#define protected public +#define private public +#include "graph/manager/graph_var_manager.h" +#include "graph/ge_context.h" +#undef protected +#undef private + +namespace ge { +/* Fixture for the VarManager tests; SetUp and TearDown are empty. */ class UtestGraphVarManagerTest : public testing::Test { + protected: + void SetUp() {} + void TearDown() {} +}; +/* GetTotalMemorySize() on VarManager::Instance(0) is expected to return SUCCESS and report 1 GiB. */ +TEST_F(UtestGraphVarManagerTest, test_get_total_memory_size) { + size_t total_mem_size = 0; + Status ret = VarManager::Instance(0)->GetTotalMemorySize(total_mem_size); + EXPECT_EQ(total_mem_size, 1024UL * 1024UL * 1024UL); + EXPECT_EQ(ret, SUCCESS); +} +/* With an empty option map, SetMemoryMallocSize() is expected to return SUCCESS with the graph limit at 26/32 and the variable limit at 5/32 of 1 GiB. */ +TEST_F(UtestGraphVarManagerTest, test_set_memory_malloc_size_no_related_option) { + const map options{}; + Status ret = VarManager::Instance(0)->SetMemoryMallocSize(options); + EXPECT_EQ(VarManager::Instance(0)->graph_mem_max_size_, floor(1024UL * 1024UL * 1024UL * (26.0f / 32.0f))); + EXPECT_EQ(VarManager::Instance(0)->var_mem_max_size_, floor(1024UL * 1024UL * 1024UL * (5.0f / 32.0f))); + EXPECT_EQ(ret, SUCCESS); +} +/* With "ge.graphMemoryMaxSize" set to 536870912 (512 MiB), the graph limit is expected to be 1 GiB / 2 while the variable limit stays at 5/32 of 1 GiB. */ +TEST_F(UtestGraphVarManagerTest, test_set_memory_malloc_size_with_user_specify_graph_mem_max_size) { + const map options{{"ge.graphMemoryMaxSize", "536870912"}}; + Status ret = VarManager::Instance(0)->SetMemoryMallocSize(options); + EXPECT_EQ(VarManager::Instance(0)->graph_mem_max_size_, floor(1024UL * 1024UL * 1024UL / 2)); + EXPECT_EQ(VarManager::Instance(0)->var_mem_max_size_, floor(1024UL * 1024UL * 1024UL * (5.0f / 32.0f))); + EXPECT_EQ(ret, SUCCESS); +} +/* With "ge.variableMemoryMaxSize" set to 536870912 (512 MiB), the variable limit is expected to be 1 GiB / 2 while the graph limit stays at 26/32 of 1 GiB. */ +TEST_F(UtestGraphVarManagerTest, test_set_memory_malloc_size_with_user_specify_var_mem_max_size) { + const map options{{"ge.variableMemoryMaxSize", "536870912"}}; + Status ret = VarManager::Instance(0)->SetMemoryMallocSize(options); + EXPECT_EQ(VarManager::Instance(0)->graph_mem_max_size_, floor(1024UL * 1024UL * 1024UL * (26.0f / 32.0f))); + EXPECT_EQ(VarManager::Instance(0)->var_mem_max_size_, floor(1024UL * 1024UL * 1024UL / 2)); + EXPECT_EQ(ret, SUCCESS); +} +} // namespace ge