| @@ -88,11 +88,9 @@ else () | |||||
| find_module(hccl libhccl.so ${GE_LIB_PATH}) | find_module(hccl libhccl.so ${GE_LIB_PATH}) | ||||
| find_module(adump_server libadump_server.a ${GE_LIB_PATH}) | find_module(adump_server libadump_server.a ${GE_LIB_PATH}) | ||||
| find_module(runtime libruntime.so ${GE_LIB_PATH}) | find_module(runtime libruntime.so ${GE_LIB_PATH}) | ||||
| find_module(runtime_compile libruntime_compile.so ${GE_LIB_PATH}) | |||||
| find_module(resource libresource.so ${GE_LIB_PATH}) | find_module(resource libresource.so ${GE_LIB_PATH}) | ||||
| find_module(ascend_hal_stub libascend_hal.so ${GE_LIB_PATH}) | find_module(ascend_hal_stub libascend_hal.so ${GE_LIB_PATH}) | ||||
| find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${GE_LIB_PATH}) | find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${GE_LIB_PATH}) | ||||
| #find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH}) | |||||
| else() | else() | ||||
| find_module(slog libalog.so ${ASCEND_ATC_DIR}) | find_module(slog libalog.so ${ASCEND_ATC_DIR}) | ||||
| find_module(opt_feature libopt_feature.so ${ASCEND_ATC_DIR}) | find_module(opt_feature libopt_feature.so ${ASCEND_ATC_DIR}) | ||||
| @@ -108,7 +106,6 @@ else () | |||||
| elseif(PLATFORM STREQUAL "inference") | elseif(PLATFORM STREQUAL "inference") | ||||
| find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR}) | find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR}) | ||||
| find_module(runtime libruntime.so ${ASCEND_ACL_DIR}) | find_module(runtime libruntime.so ${ASCEND_ACL_DIR}) | ||||
| find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR}) | |||||
| find_module(msprofiler_ext libmsprofiler.a ${ASCEND_ACL_DIR}) | find_module(msprofiler_ext libmsprofiler.a ${ASCEND_ACL_DIR}) | ||||
| if(PRODUCT STREQUAL "flr3") | if(PRODUCT STREQUAL "flr3") | ||||
| elseif(PRODUCT STREQUAL "flr1") | elseif(PRODUCT STREQUAL "flr1") | ||||
| @@ -120,10 +117,9 @@ else () | |||||
| endif() | endif() | ||||
| elseif(PLATFORM STREQUAL "all") | elseif(PLATFORM STREQUAL "all") | ||||
| find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR}) | find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR}) | ||||
| find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR}) | |||||
| find_module(runtime libruntime.so ${ASCEND_ATC_DIR}) | |||||
| find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${ASCEND_RUNTIME_DIR}) | find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${ASCEND_RUNTIME_DIR}) | ||||
| find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}) | |||||
| find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR}) | |||||
| find_module(ascend_hal_stub libascend_hal.so ${ASCEND_ATC_DIR}/stub) | |||||
| find_module(msprofiler_ext libmsprofiler.a ${ASCEND_ACL_DIR}) | find_module(msprofiler_ext libmsprofiler.a ${ASCEND_ACL_DIR}) | ||||
| else() | else() | ||||
| message(STATUS "PLATFORM param is invalid, should be train or inference, you choose nothing!") | message(STATUS "PLATFORM param is invalid, should be train or inference, you choose nothing!") | ||||
| @@ -10,12 +10,17 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR | |||||
| message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") | message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") | ||||
| endif() | endif() | ||||
| if (ENABLE_GITEE) | |||||
| set(REQ_URL "https://gitee.com/mirrors/gflags/repository/archive/v2.2.2.tar.gz") | |||||
| set(MD5 "") | |||||
| if (GE_PB_PKG) | |||||
| set(REQ_URL "${GE_PB_PKG}/libs/gflags/v2.2.2.tar.gz") | |||||
| set(MD5 "1a865b93bacfa963201af3f75b7bd64c") | |||||
| else() | else() | ||||
| set(REQ_URL "https://github.com/gflags/gflags/archive/v2.2.2.tar.gz") | |||||
| set(MD5 "") | |||||
| if (ENABLE_GITEE) | |||||
| set(REQ_URL "https://gitee.com/mirrors/gflags/repository/archive/v2.2.2.tar.gz") | |||||
| set(MD5 "") | |||||
| else() | |||||
| set(REQ_URL "https://github.com/gflags/gflags/archive/v2.2.2.tar.gz") | |||||
| set(MD5 "1a865b93bacfa963201af3f75b7bd64c") | |||||
| endif () | |||||
| endif () | endif () | ||||
| set (gflags_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -Dgoogle=ascend_private") | set (gflags_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -Dgoogle=ascend_private") | ||||
| @@ -112,6 +112,8 @@ set(EXECUTOR_SRC_LIST | |||||
| "common/dump/dump_op.cc" | "common/dump/dump_op.cc" | ||||
| "common/dump/exception_dumper.cc" | "common/dump/exception_dumper.cc" | ||||
| "common/dump/opdebug_register.cc" | "common/dump/opdebug_register.cc" | ||||
| "common/ge/op_tiling_manager.cc" | |||||
| "common/ge/plugin_manager.cc" | |||||
| "common/profiling/ge_profiling.cc" | "common/profiling/ge_profiling.cc" | ||||
| "common/profiling/profiling_manager.cc" | "common/profiling/profiling_manager.cc" | ||||
| "executor/ge_executor.cc" | "executor/ge_executor.cc" | ||||
| @@ -259,6 +261,8 @@ set(EXECUTOR_SRC_LIST | |||||
| set(COMPILER_SRC_LIST | set(COMPILER_SRC_LIST | ||||
| "analyzer/analyzer.cc" | "analyzer/analyzer.cc" | ||||
| "common/dump/dump_op.cc" | "common/dump/dump_op.cc" | ||||
| "common/ge/op_tiling_manager.cc" | |||||
| "common/ge/plugin_manager.cc" | |||||
| "common/helper/model_cache_helper.cc" | "common/helper/model_cache_helper.cc" | ||||
| "common/profiling/profiling_manager.cc" | "common/profiling/profiling_manager.cc" | ||||
| "engine_manager/dnnengine_manager.cc" | "engine_manager/dnnengine_manager.cc" | ||||
| @@ -619,7 +623,6 @@ target_compile_definitions(ge_compiler PRIVATE | |||||
| REUSE_MEMORY=1 | REUSE_MEMORY=1 | ||||
| FMK_SUPPORT_DUMP | FMK_SUPPORT_DUMP | ||||
| FMK_HOST_INFER | FMK_HOST_INFER | ||||
| COMPILE_OMG_PACKAGE | |||||
| google=ascend_private | google=ascend_private | ||||
| FUNC_VISIBILITY | FUNC_VISIBILITY | ||||
| $<$<STREQUAL:${ENABLE_OPEN_SRC},True>:ONLY_COMPILE_OPEN_SRC> | $<$<STREQUAL:${ENABLE_OPEN_SRC},True>:ONLY_COMPILE_OPEN_SRC> | ||||
| @@ -681,8 +684,7 @@ target_link_libraries(ge_compiler PRIVATE | |||||
| c_sec | c_sec | ||||
| error_manager | error_manager | ||||
| slog | slog | ||||
| $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:runtime>> | |||||
| $<$<BOOL:${ENABLE_OPEN_SRC}>:$<BUILD_INTERFACE:runtime_compile>> | |||||
| runtime | |||||
| opt_feature | opt_feature | ||||
| -Wl,--as-needed | -Wl,--as-needed | ||||
| json | json | ||||
| @@ -350,7 +350,7 @@ Status FftsPlusTaskInfo::InitAicAivCtx(const domi::FftsPlusAicAivCtxDef &ctx_def | |||||
| i_cache_prefetch_cnt_2)); | i_cache_prefetch_cnt_2)); | ||||
| ctx->tailTaskStartPcL = static_cast<uint32_t>(reinterpret_cast<uintptr_t>(tail_task_start_pc) & 0XFFFFFFFF); | ctx->tailTaskStartPcL = static_cast<uint32_t>(reinterpret_cast<uintptr_t>(tail_task_start_pc) & 0XFFFFFFFF); | ||||
| ctx->tailTaskStartPcH = static_cast<uint16_t>((reinterpret_cast<uintptr_t>(tail_task_start_pc) >> 32) & 0X0000FFFF); | ctx->tailTaskStartPcH = static_cast<uint16_t>((reinterpret_cast<uintptr_t>(tail_task_start_pc) >> 32) & 0X0000FFFF); | ||||
| uint32_t i_cache_prefetch_cnt = std::max(i_cache_prefetch_cnt_1, i_cache_prefetch_cnt_2); | |||||
| uint32_t i_cache_prefetch_cnt = std::min(i_cache_prefetch_cnt_1, i_cache_prefetch_cnt_2); | |||||
| ctx->icachePrefetchCnt = static_cast<uint16_t>(i_cache_prefetch_cnt & 0X0000001F); // 5 bits, 0001,1111 | ctx->icachePrefetchCnt = static_cast<uint16_t>(i_cache_prefetch_cnt & 0X0000001F); // 5 bits, 0001,1111 | ||||
| if (ctx_def.src_slot_size() != kSrcSlotNum) { | if (ctx_def.src_slot_size() != kSrcSlotNum) { | ||||
| @@ -526,8 +526,7 @@ Status FftsPlusTaskInfo::InitMixAicAivCtx(const domi::FftsPlusMixAicAivCtxDef &c | |||||
| ctx->tailAicTaskStartPcL = static_cast<uint32_t>(reinterpret_cast<uintptr_t>(tail_aic_task_start_pc) & 0XFFFFFFFF); | ctx->tailAicTaskStartPcL = static_cast<uint32_t>(reinterpret_cast<uintptr_t>(tail_aic_task_start_pc) & 0XFFFFFFFF); | ||||
| ctx->tailAicTaskStartPcH = static_cast<uint16_t>((reinterpret_cast<uintptr_t>(tail_aic_task_start_pc) >> 32) & | ctx->tailAicTaskStartPcH = static_cast<uint16_t>((reinterpret_cast<uintptr_t>(tail_aic_task_start_pc) >> 32) & | ||||
| 0X0000FFFF); | 0X0000FFFF); | ||||
| uint32_t aic_i_cache_prefetch_cnt = std::max(i_cache_prefetch_cnt_1, i_cache_prefetch_cnt_2); | |||||
| // TODO | |||||
| uint32_t aic_i_cache_prefetch_cnt = std::min(i_cache_prefetch_cnt_1, i_cache_prefetch_cnt_2); | |||||
| ctx->icachePrefetchCnt = static_cast<uint16_t>(aic_i_cache_prefetch_cnt & 0X0000001F); // 5 bits, 0001,1111 | ctx->icachePrefetchCnt = static_cast<uint16_t>(aic_i_cache_prefetch_cnt & 0X0000001F); // 5 bits, 0001,1111 | ||||
| uint32_t i_cache_prefetch_cnt_3; | uint32_t i_cache_prefetch_cnt_3; | ||||
| @@ -545,9 +544,10 @@ Status FftsPlusTaskInfo::InitMixAicAivCtx(const domi::FftsPlusMixAicAivCtxDef &c | |||||
| ctx->tailAivTaskStartPcL = static_cast<uint32_t>(reinterpret_cast<uintptr_t>(tail_aiv_task_start_pc) & 0XFFFFFFFF); | ctx->tailAivTaskStartPcL = static_cast<uint32_t>(reinterpret_cast<uintptr_t>(tail_aiv_task_start_pc) & 0XFFFFFFFF); | ||||
| ctx->tailAivTaskStartPcH = static_cast<uint16_t>((reinterpret_cast<uintptr_t>(tail_aiv_task_start_pc) >> 32) & | ctx->tailAivTaskStartPcH = static_cast<uint16_t>((reinterpret_cast<uintptr_t>(tail_aiv_task_start_pc) >> 32) & | ||||
| 0X0000FFFF); | 0X0000FFFF); | ||||
| uint32_t aiv_i_cache_prefetch_cnt = std::max(i_cache_prefetch_cnt_3, i_cache_prefetch_cnt_4); | |||||
| uint32_t aiv_i_cache_prefetch_cnt = std::min(i_cache_prefetch_cnt_3, i_cache_prefetch_cnt_4); | |||||
| // TODO | // TODO | ||||
| ctx->icachePrefetchCnt = static_cast<uint16_t>(aiv_i_cache_prefetch_cnt & 0X0000001F); // 5 bits, 0001,1111 | |||||
| ctx->icachePrefetchCnt = static_cast<uint16_t>( | |||||
| std::min(aic_i_cache_prefetch_cnt, aiv_i_cache_prefetch_cnt) & 0X0000001F); // 5 bits, 0001,1111 | |||||
| if (ctx_def.src_slot_size() != kSrcSlotNum) { | if (ctx_def.src_slot_size() != kSrcSlotNum) { | ||||
| REPORT_INNER_ERROR("E19999", "Size of src_slot in FftsPlusMixAicAivCtxDef should be %d, but %d exactly", | REPORT_INNER_ERROR("E19999", "Size of src_slot in FftsPlusMixAicAivCtxDef should be %d, but %d exactly", | ||||
| @@ -913,11 +913,11 @@ void FftsPlusTaskInfo::SetAdditionalDatatoCtx(const domi::FftsPlusTaskDef &task_ | |||||
| Status FftsPlusTaskInfo::UpdateMixAicAivCtxParam(const domi::FftsPlusMixAicAivCtxDef &ctx_def, size_t ctx_idx) { | Status FftsPlusTaskInfo::UpdateMixAicAivCtxParam(const domi::FftsPlusMixAicAivCtxDef &ctx_def, size_t ctx_idx) { | ||||
| if (ctx_additional_data_.count(ctx_idx) == 0) { | if (ctx_additional_data_.count(ctx_idx) == 0) { | ||||
| GELOGD("ctx idx:%d not in ctx additional data"); | |||||
| GELOGD("ctx idx:%zu not in ctx additional data"); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| if (ctx_additional_data_[ctx_idx].count(kModeInArgsFirstField) == 0) { | if (ctx_additional_data_[ctx_idx].count(kModeInArgsFirstField) == 0) { | ||||
| GELOGD("ctx idx:%d need not to save mode in args first field"); | |||||
| GELOGD("ctx idx:%zu need not to save mode in args first field"); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| if (rtApp_addr_ == 0) { | if (rtApp_addr_ == 0) { | ||||
| @@ -20,6 +20,7 @@ | |||||
| #include "graph/manager/graph_mem_manager.h" | #include "graph/manager/graph_mem_manager.h" | ||||
| #include "graph/manager/trans_var_data_utils.h" | #include "graph/manager/trans_var_data_utils.h" | ||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| #include "graph/ge_context.h" | |||||
| using std::map; | using std::map; | ||||
| using std::string; | using std::string; | ||||
| @@ -767,25 +768,52 @@ Status VarManager::GetChangedGraphId(const std::string &var_name, uint32_t &grap | |||||
| return var_resource_->GetChangedGraphId(var_name, graph_id); | return var_resource_->GetChangedGraphId(var_name, graph_id); | ||||
| } | } | ||||
| Status VarManager::GetTotalMemorySize(size_t &total_mem_size) { | |||||
| rtError_t rt_ret = rtSetDevice(GetContext().DeviceId()); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, ret:0x%X", | |||||
| GetContext().DeviceId(), rt_ret); | |||||
| GELOGE(RT_FAILED, "[Call][RtSetDevice] failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret); | |||||
| return RT_FAILED; | |||||
| } | |||||
| size_t free_mem = 0; | |||||
| rt_ret = rtMemGetInfoEx(RT_MEMORYINFO_HBM, &free_mem, &total_mem_size); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| REPORT_CALL_ERROR("E19999", "Call rtMemGetInfo failed, ret:0x%X", rt_ret); | |||||
| GELOGE(RT_FAILED, "[Call][RtMemGetInfo] failed, ret:0x%X", rt_ret); | |||||
| return RT_FAILED; | |||||
| } | |||||
| rt_ret = rtDeviceReset(GetContext().DeviceId()); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, ret:0x%X", | |||||
| GetContext().DeviceId(), rt_ret); | |||||
| GELOGE(RT_FAILED, "[Call][RtDeviceReset] failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret); | |||||
| return RT_FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status VarManager::SetMemoryMallocSize(const map<string, string> &options) { | Status VarManager::SetMemoryMallocSize(const map<string, string> &options) { | ||||
| auto it = options.find(GRAPH_MEMORY_MAX_SIZE); | |||||
| if (it == options.end()) { | |||||
| graph_mem_max_size_ = kGraphMemoryManagerMallocMaxSize; | |||||
| } else { | |||||
| string graph_memory_manager_malloc_max_size = it->second; | |||||
| size_t total_mem_size = 0; | |||||
| GE_CHK_STATUS_RET_NOLOG(VarManager::GetTotalMemorySize(total_mem_size)); | |||||
| GEEVENT("Total memory size is %zu", total_mem_size); | |||||
| graph_mem_max_size_ = floor(total_mem_size * kGraphMemoryManagerMallocRatio); | |||||
| var_mem_max_size_ = floor(total_mem_size * kVarMemoryManagerMallocRatio); | |||||
| auto it1 = options.find(GRAPH_MEMORY_MAX_SIZE); | |||||
| if (it1 != options.end()) { | |||||
| string graph_memory_manager_malloc_max_size = it1->second; | |||||
| ge::Status ret = ParseMemoryMallocSize(graph_memory_manager_malloc_max_size, graph_mem_max_size_); | ge::Status ret = ParseMemoryMallocSize(graph_memory_manager_malloc_max_size, graph_mem_max_size_); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "[Call][ParseMemoryMallocSize] failed, session id:%lu.", session_id_); | GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "[Call][ParseMemoryMallocSize] failed, session id:%lu.", session_id_); | ||||
| return ge::GE_GRAPH_OPTIONS_INVALID; | return ge::GE_GRAPH_OPTIONS_INVALID; | ||||
| } | } | ||||
| GELOGI("The max size for graph mem is set to %zu", graph_mem_max_size_); | |||||
| } | } | ||||
| it = options.find(VARIABLE_MEMORY_MAX_SIZE); | |||||
| if (it == options.end()) { | |||||
| var_mem_max_size_ = kMemoryVarManagerMallocSize; | |||||
| } else { | |||||
| string memory_var_manager_malloc_size = it->second; | |||||
| auto it2 = options.find(VARIABLE_MEMORY_MAX_SIZE); | |||||
| if (it2 != options.end()) { | |||||
| string memory_var_manager_malloc_size = it2->second; | |||||
| ge::Status ret = ParseMemoryMallocSize(memory_var_manager_malloc_size, var_mem_max_size_); | ge::Status ret = ParseMemoryMallocSize(memory_var_manager_malloc_size, var_mem_max_size_); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "[Call][ParseMemoryMallocSize] failed, session id:%lu.", session_id_); | GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "[Call][ParseMemoryMallocSize] failed, session id:%lu.", session_id_); | ||||
| @@ -793,6 +821,8 @@ Status VarManager::SetMemoryMallocSize(const map<string, string> &options) { | |||||
| } | } | ||||
| } | } | ||||
| GEEVENT("The graph_mem_max_size is %zu and the var_mem_max_size is %zu", graph_mem_max_size_, var_mem_max_size_); | |||||
| var_mem_logic_base_ = graph_mem_max_size_ + kGraphMemoryBuffer; | var_mem_logic_base_ = graph_mem_max_size_ + kGraphMemoryBuffer; | ||||
| if (var_mem_logic_base_ > kMaxMemorySize) { | if (var_mem_logic_base_ > kMaxMemorySize) { | ||||
| REPORT_INNER_ERROR("E19999", "var_login_base:%zu can not exeed limit:%zu, session_id:%lu, check invalid", | REPORT_INNER_ERROR("E19999", "var_login_base:%zu can not exeed limit:%zu, session_id:%lu, check invalid", | ||||
| @@ -43,6 +43,8 @@ const size_t kMaxMemorySize = 256UL * 1024UL * 1024UL * 1024UL; | |||||
| const char kEnvGeuseStaticMemory[] = "GE_USE_STATIC_MEMORY"; | const char kEnvGeuseStaticMemory[] = "GE_USE_STATIC_MEMORY"; | ||||
| const uint64_t kSessionMemAlignSize = 512; | const uint64_t kSessionMemAlignSize = 512; | ||||
| const size_t kSessionMemAlignUnit = 2; | const size_t kSessionMemAlignUnit = 2; | ||||
| const double kGraphMemoryManagerMallocRatio = 26.0 / 32.0; | |||||
| const double kVarMemoryManagerMallocRatio = 5.0 / 32.0; | |||||
| enum MemStatus { | enum MemStatus { | ||||
| NORMAL = 0, | NORMAL = 0, | ||||
| @@ -301,6 +303,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { | |||||
| mutable std::recursive_mutex mutex_; | mutable std::recursive_mutex mutex_; | ||||
| Status ParseMemoryMallocSize(std::string &memory_size, size_t &my_size); | Status ParseMemoryMallocSize(std::string &memory_size, size_t &my_size); | ||||
| Status GetTotalMemorySize(size_t &total_mem_size); | |||||
| }; | }; | ||||
| class VarManagerPool { | class VarManagerPool { | ||||
| @@ -60,7 +60,6 @@ const char *const kEngineNameRts = "DNN_VM_RTS_OP_STORE"; | |||||
| const char *const kForceInfershape = "_force_infershape_when_running"; | const char *const kForceInfershape = "_force_infershape_when_running"; | ||||
| const std::set<std::string> kExecutionDependentTypes{ IF, STATELESSIF, CASE, STREAMSWITCH }; | const std::set<std::string> kExecutionDependentTypes{ IF, STATELESSIF, CASE, STREAMSWITCH }; | ||||
| const std::set<std::string> kMergeInputSkipTypes{ STREAMACTIVE, STREAMSWITCH, CONSTANT, CONSTANTOP }; | |||||
| const std::set<std::string> kStreamActiveTypes{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION }; | const std::set<std::string> kStreamActiveTypes{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION }; | ||||
| Status SetOutputNameAttr(ComputeGraph &graph) { | Status SetOutputNameAttr(ComputeGraph &graph) { | ||||
| @@ -519,170 +518,6 @@ Status HybridModelBuilder::UpdateAnchorStatus(const NodePtr &node) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status HybridModelBuilder::DoUnlinkDataAnchors(const OutDataAnchorPtr &out_data_anchor, | |||||
| const InDataAnchorPtr &in_data_anchor) { | |||||
| GE_CHK_GRAPH_STATUS_RET(out_data_anchor->Unlink(in_data_anchor), | |||||
| "[Invoke][Unlink] failed to unlink %s:%d from %s:%d", | |||||
| out_data_anchor->GetOwnerNode()->GetName().c_str(), out_data_anchor->GetIdx(), | |||||
| in_data_anchor->GetOwnerNode()->GetName().c_str(), in_data_anchor->GetIdx()); | |||||
| GELOGD("Succeeded in unlinking %s:%d from %s:%d", | |||||
| out_data_anchor->GetOwnerNode()->GetName().c_str(), | |||||
| out_data_anchor->GetIdx(), | |||||
| in_data_anchor->GetOwnerNode()->GetName().c_str(), | |||||
| in_data_anchor->GetIdx()); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status HybridModelBuilder::DoLinkDataAnchors(OutDataAnchorPtr &out_data_anchor, InDataAnchorPtr &in_data_anchor) { | |||||
| GE_CHK_GRAPH_STATUS_RET(out_data_anchor->LinkTo(in_data_anchor), "[Invoke][LinkTo]Failed to link %s:%d to %s:%d", | |||||
| out_data_anchor->GetOwnerNode()->GetName().c_str(), | |||||
| out_data_anchor->GetIdx(), | |||||
| in_data_anchor->GetOwnerNode()->GetName().c_str(), | |||||
| in_data_anchor->GetIdx()); | |||||
| GELOGD("Succeeded in linking %s:%d to %s:%d", | |||||
| out_data_anchor->GetOwnerNode()->GetName().c_str(), | |||||
| out_data_anchor->GetIdx(), | |||||
| in_data_anchor->GetOwnerNode()->GetName().c_str(), | |||||
| in_data_anchor->GetIdx()); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status HybridModelBuilder::MergeInputNodes(ComputeGraph &graph) { | |||||
| const auto &wrapped_node = graph.GetParentNode(); | |||||
| std::set<NodePtr> root_nodes; | |||||
| for (const auto &node : graph.GetDirectNode()) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| if (node->GetType() != DATA_TYPE) { | |||||
| if (node->GetInDataNodes().empty()) { | |||||
| root_nodes.emplace(node); | |||||
| } | |||||
| continue; | |||||
| } | |||||
| auto data_op_desc = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(data_op_desc); | |||||
| uint32_t parent_index = 0; | |||||
| if (!AttrUtils::GetInt(data_op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { | |||||
| GELOGE(FAILED, "[Invoke][GetInt] failed, node:[%s] attr:[%s]", | |||||
| data_op_desc->GetName().c_str(), ATTR_NAME_PARENT_NODE_INDEX.c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "GetInt failed, node:[%s] attr:[%s]", | |||||
| data_op_desc->GetName().c_str(), ATTR_NAME_PARENT_NODE_INDEX.c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| auto wrapped_node_in_anchor = wrapped_node->GetInDataAnchor(parent_index); | |||||
| GE_CHECK_NOTNULL(wrapped_node_in_anchor); | |||||
| auto src_out_anchor = wrapped_node_in_anchor->GetPeerOutAnchor(); | |||||
| if (src_out_anchor == nullptr || src_out_anchor->GetOwnerNode() == nullptr) { | |||||
| continue; | |||||
| } | |||||
| wrapped_node_in_anchor->UnlinkAll(); | |||||
| // link src to outputs of DataNode | |||||
| for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { | |||||
| GE_CHECK_NOTNULL(out_data_anchor); | |||||
| for (auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | |||||
| auto dst_node = peer_in_data_anchor->GetOwnerNode(); | |||||
| GE_CHECK_NOTNULL(dst_node); | |||||
| const auto in_nodes = dst_node->GetInDataNodes(); | |||||
| if (std::all_of(in_nodes.begin(), in_nodes.end(), [](const NodePtr &n) { return n->GetType() == DATA; })) { | |||||
| root_nodes.emplace(dst_node); | |||||
| } | |||||
| GE_CHK_STATUS_RET_NOLOG(DoUnlinkDataAnchors(out_data_anchor, peer_in_data_anchor)); | |||||
| GE_CHK_STATUS_RET_NOLOG(DoLinkDataAnchors(src_out_anchor, peer_in_data_anchor)); | |||||
| } | |||||
| } | |||||
| } | |||||
| // transfer in control edges to all root nodes | |||||
| for (auto &root_node : root_nodes) { | |||||
| auto in_nodes = root_node->GetInAllNodes(); | |||||
| std::set<NodePtr> in_node_set(in_nodes.begin(), in_nodes.end()); | |||||
| for (auto &in_control_node : wrapped_node->GetInControlNodes()) { | |||||
| if (in_node_set.count(in_control_node) == 0 && kMergeInputSkipTypes.count(root_node->GetType()) == 0) { | |||||
| GELOGD("[%s] Restore control edge to [%s]", in_control_node->GetName().c_str(), root_node->GetName().c_str()); | |||||
| GE_CHECK_NOTNULL(in_control_node->GetOutControlAnchor()); | |||||
| (void) in_control_node->GetOutControlAnchor()->LinkTo(root_node->GetInControlAnchor()); | |||||
| } | |||||
| } | |||||
| } | |||||
| wrapped_node->GetInControlAnchor()->UnlinkAll(); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status HybridModelBuilder::MergeNetOutputNode(ComputeGraph &graph) { | |||||
| const auto &parent_node = graph.GetParentNode(); | |||||
| const NodePtr &net_output_node = graph.FindFirstNodeMatchType(NETOUTPUT); | |||||
| if (net_output_node == nullptr) { | |||||
| GELOGD("Graph has no netoutput no need to merge"); | |||||
| return SUCCESS; | |||||
| } | |||||
| const auto &net_output_desc = net_output_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(net_output_desc); | |||||
| auto all_in_nodes = net_output_node->GetInAllNodes(); | |||||
| auto all_out_nodes = parent_node->GetOutAllNodes(); | |||||
| net_output_node->GetInControlAnchor()->UnlinkAll(); | |||||
| parent_node->GetOutControlAnchor()->UnlinkAll(); | |||||
| for (const auto &in_data_anchor : net_output_node->GetAllInDataAnchors()) { | |||||
| auto src_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
| GE_CHECK_NOTNULL(src_out_anchor); | |||||
| GE_CHECK_NOTNULL(src_out_anchor->GetOwnerNode()); | |||||
| GE_CHK_STATUS_RET_NOLOG(DoUnlinkDataAnchors(src_out_anchor, in_data_anchor)); | |||||
| auto index = in_data_anchor->GetIdx(); | |||||
| auto input_desc = net_output_desc->MutableInputDesc(index); | |||||
| if (input_desc == nullptr) { | |||||
| GELOGE(INTERNAL_ERROR, "[Invoke][MutableInputDesc][%s] Failed to get input desc[%d]", | |||||
| net_output_desc->GetName().c_str(), index); | |||||
| REPORT_CALL_ERROR("E19999", "[%s] Failed to get input desc[%d].", net_output_desc->GetName().c_str(), index); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| uint32_t parent_index = 0; | |||||
| if (!AttrUtils::GetInt(input_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { | |||||
| GELOGW("SubGraph: %s NetOutput input tensor %d, attr %s not found.", | |||||
| graph.GetName().c_str(), index, ATTR_NAME_PARENT_NODE_INDEX.c_str()); | |||||
| continue; | |||||
| } | |||||
| const OutDataAnchorPtr &parent_out_anchor = parent_node->GetOutDataAnchor(parent_index); | |||||
| GE_CHECK_NOTNULL(parent_out_anchor); | |||||
| for (InDataAnchorPtr &dst_in_anchor : parent_out_anchor->GetPeerInDataAnchors()) { | |||||
| if (dst_in_anchor == nullptr) { | |||||
| continue; | |||||
| } | |||||
| GE_CHECK_NOTNULL(dst_in_anchor->GetOwnerNode()); | |||||
| GE_CHK_STATUS_RET_NOLOG(DoUnlinkDataAnchors(parent_out_anchor, dst_in_anchor)); | |||||
| GE_CHK_STATUS_RET_NOLOG(DoLinkDataAnchors(src_out_anchor, dst_in_anchor)); | |||||
| } | |||||
| } | |||||
| // transfer out control edges | |||||
| std::set<NodePtr> in_node_set(all_in_nodes.begin(), all_in_nodes.end()); | |||||
| std::set<NodePtr> out_node_set(all_out_nodes.begin(), all_out_nodes.end()); | |||||
| for (auto &src_node : in_node_set) { | |||||
| GELOGD("[%s] process in node.", src_node->GetName().c_str()); | |||||
| auto out_nodes = src_node->GetOutAllNodes(); | |||||
| std::set<NodePtr> node_set(out_nodes.begin(), out_nodes.end()); | |||||
| for (auto &dst_node : out_node_set) { | |||||
| if (node_set.count(dst_node) == 0) { | |||||
| src_node->GetOutControlAnchor()->LinkTo(dst_node->GetInControlAnchor()); | |||||
| GELOGD("[%s] Restore control edge to [%s]", src_node->GetName().c_str(), dst_node->GetName().c_str()); | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph) { | Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph) { | ||||
| merged_graph = MakeShared<ComputeGraph>("MergedGraph"); | merged_graph = MakeShared<ComputeGraph>("MergedGraph"); | ||||
| merged_graph->SetGraphUnknownFlag(root_graph->GetGraphUnknownFlag()); | merged_graph->SetGraphUnknownFlag(root_graph->GetGraphUnknownFlag()); | ||||
| @@ -716,9 +551,21 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeG | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, merged_graph, *subgraph), | |||||
| const auto &filter = [](const ComputeGraphPtr &graph) { | |||||
| const auto &parent_node = graph->GetParentNode(); | |||||
| if (parent_node == nullptr || parent_node->GetOpDesc() == nullptr) { | |||||
| return false; | |||||
| } | |||||
| if ((parent_node->GetType() != PARTITIONEDCALL) || | |||||
| (parent_node->GetOpDesc()->GetSubgraphInstanceNames().size() != 1)) { | |||||
| return false; | |||||
| } | |||||
| return graph->GetGraphUnknownFlag(); | |||||
| }; | |||||
| GE_CHK_GRAPH_STATUS_RET(GraphUtils::UnfoldSubgraph(subgraph, filter), | |||||
| "[Invoke][UnfoldSubgraph][%s] Failed to merge subgraph.", | "[Invoke][UnfoldSubgraph][%s] Failed to merge subgraph.", | ||||
| subgraph->GetName().c_str()); | |||||
| subgraph->GetName().c_str()) | |||||
| } | } | ||||
| // invoke before adding subgraphs. in case modify node id in known-shaped subgraphs. | // invoke before adding subgraphs. in case modify node id in known-shaped subgraphs. | ||||
| @@ -744,56 +591,6 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeG | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status HybridModelBuilder::UnfoldSubgraph(ComputeGraphPtr &root_graph, | |||||
| ComputeGraphPtr &parent_graph, | |||||
| ComputeGraph &sub_graph) { | |||||
| auto parent_node = sub_graph.GetParentNode(); | |||||
| GE_CHECK_NOTNULL(parent_node); | |||||
| GE_CHK_STATUS_RET(MergeInputNodes(sub_graph), | |||||
| "[Invoke][MergeInputNodes][%s] Failed to merge data nodes for subgraph", | |||||
| sub_graph.GetName().c_str()); | |||||
| GE_CHK_STATUS_RET(MergeNetOutputNode(sub_graph), | |||||
| "[Invoke][MergeNetOutputNode][%s] Failed to merge net output nodes for subgraph", | |||||
| sub_graph.GetName().c_str()); | |||||
| GELOGD("[%s] Done merging subgraph inputs and outputs successfully", sub_graph.GetName().c_str()); | |||||
| for (auto &sub_node : sub_graph.GetDirectNode()) { | |||||
| auto sub_op_type = sub_node->GetType(); | |||||
| if (sub_op_type == DATA_TYPE || sub_op_type == NETOUTPUT) { | |||||
| continue; | |||||
| } | |||||
| if (sub_op_type == PARTITIONEDCALL) { | |||||
| auto sub_sub_graph = NodeUtils::GetSubgraph(*sub_node, kSubgraphIndex); | |||||
| GE_CHECK_NOTNULL(sub_sub_graph); | |||||
| if (sub_sub_graph->GetGraphUnknownFlag()) { | |||||
| GE_CHK_STATUS_RET(UnfoldSubgraph(root_graph, parent_graph, *sub_sub_graph), | |||||
| "[Invoke][UnfoldSubgraph][%s] Failed to merge subgraph", | |||||
| sub_sub_graph->GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| } | |||||
| if (!sub_node->GetOpDesc()->GetSubgraphInstanceNames().empty()) { | |||||
| for (size_t i = 0; i < sub_node->GetOpDesc()->GetSubgraphInstanceNames().size(); ++i) { | |||||
| auto sub_sub_graph = NodeUtils::GetSubgraph(*sub_node, i); | |||||
| GE_CHECK_NOTNULL(sub_sub_graph); | |||||
| sub_sub_graph->SetParentGraph(parent_graph); | |||||
| } | |||||
| } | |||||
| parent_graph->AddNode(sub_node); | |||||
| GELOGD("[%s::%s] added to parent graph: [%s].", | |||||
| sub_graph.GetName().c_str(), | |||||
| sub_node->GetName().c_str(), | |||||
| parent_graph->GetName().c_str()); | |||||
| sub_node->SetOwnerComputeGraph(parent_graph); | |||||
| } | |||||
| GELOGD("[%s] Done merging subgraph. remove it from root graph", sub_graph.GetName().c_str()); | |||||
| root_graph->RemoveSubgraph(sub_graph.GetName()); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status HybridModelBuilder::BuildOutputMapping(GraphItem &graph_item, | Status HybridModelBuilder::BuildOutputMapping(GraphItem &graph_item, | ||||
| const NodeItem &node_item, | const NodeItem &node_item, | ||||
| bool is_root_graph) { | bool is_root_graph) { | ||||
| @@ -39,16 +39,11 @@ class HybridModelBuilder { | |||||
| private: | private: | ||||
| static Status UpdateAnchorStatus(const NodePtr &node); | static Status UpdateAnchorStatus(const NodePtr &node); | ||||
| static Status DoUnlinkDataAnchors(const OutDataAnchorPtr &out_data_anchor, const InDataAnchorPtr &in_data_anchor); | |||||
| static Status DoLinkDataAnchors(OutDataAnchorPtr &out_data_anchor, InDataAnchorPtr &in_data_anchor); | |||||
| static NodePtr GetPeerNode(const InDataAnchorPtr &in_data_anchor); | static NodePtr GetPeerNode(const InDataAnchorPtr &in_data_anchor); | ||||
| static Status GetParentNodeOutputIndex(const OpDesc &op_desc, int index, uint32_t &out_index); | static Status GetParentNodeOutputIndex(const OpDesc &op_desc, int index, uint32_t &out_index); | ||||
| static Status GetPeerNodeAcrossSubGraphs(const NodePtr &data_node, NodePtr &peer_node, int &peer_out_index); | static Status GetPeerNodeAcrossSubGraphs(const NodePtr &data_node, NodePtr &peer_node, int &peer_out_index); | ||||
| static Status HandleDtString(const GeTensor &tensor, void *var_addr); | static Status HandleDtString(const GeTensor &tensor, void *var_addr); | ||||
| static Status MergeInputNodes(ComputeGraph &compute_graph); | |||||
| static Status MergeNetOutputNode(ComputeGraph &compute_graph); | |||||
| static Status UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph); | static Status UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph); | ||||
| static Status UnfoldSubgraph(ComputeGraphPtr &root_graph, ComputeGraphPtr &parent_graph, ComputeGraph &sub_graph); | |||||
| static Status BuildInputMapping(GraphItem &graph_item, | static Status BuildInputMapping(GraphItem &graph_item, | ||||
| std::vector<NodeItem *> &data_nodes, | std::vector<NodeItem *> &data_nodes, | ||||
| bool is_root_graph); | bool is_root_graph); | ||||
| @@ -22,7 +22,6 @@ target_compile_options(atc_atc.bin PRIVATE | |||||
| target_compile_definitions(atc_atc.bin PRIVATE | target_compile_definitions(atc_atc.bin PRIVATE | ||||
| PROTOBUF_INLINE_NOT_IN_HEADERS=0 | PROTOBUF_INLINE_NOT_IN_HEADERS=0 | ||||
| COMPILE_OMG_PACKAGE | |||||
| google=ascend_private | google=ascend_private | ||||
| LOG_CPP | LOG_CPP | ||||
| FUNC_VISIBILITY | FUNC_VISIBILITY | ||||
| @@ -48,6 +47,7 @@ target_include_directories(atc_atc.bin PRIVATE | |||||
| target_link_options(atc_atc.bin PRIVATE | target_link_options(atc_atc.bin PRIVATE | ||||
| -Wl,-Bsymbolic | -Wl,-Bsymbolic | ||||
| -Wl,-rpath-link,${ASCEND_ATC_DIR}/stub | |||||
| ) | ) | ||||
| target_link_libraries(atc_atc.bin PRIVATE | target_link_libraries(atc_atc.bin PRIVATE | ||||
| @@ -62,8 +62,7 @@ target_link_libraries(atc_atc.bin PRIVATE | |||||
| parser_common | parser_common | ||||
| gflags | gflags | ||||
| json | json | ||||
| $<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:runtime>> | |||||
| $<$<BOOL:${ENABLE_OPEN_SRC}>:$<BUILD_INTERFACE:runtime_compile>> | |||||
| runtime | |||||
| slog | slog | ||||
| static_mmpa | static_mmpa | ||||
| -lrt | -lrt | ||||
| @@ -92,7 +91,6 @@ target_compile_options(fwk_atc.bin PRIVATE | |||||
| target_compile_definitions(fwk_atc.bin PRIVATE | target_compile_definitions(fwk_atc.bin PRIVATE | ||||
| PROTOBUF_INLINE_NOT_IN_HEADERS=0 | PROTOBUF_INLINE_NOT_IN_HEADERS=0 | ||||
| COMPILE_OMG_PACKAGE | |||||
| google=ascend_private | google=ascend_private | ||||
| LOG_CPP | LOG_CPP | ||||
| FUNC_VISIBILITY | FUNC_VISIBILITY | ||||
| @@ -193,6 +193,12 @@ rtError_t rtMemGetInfo(size_t *free, size_t *total) { | |||||
| return RT_ERROR_NONE; | return RT_ERROR_NONE; | ||||
| } | } | ||||
| // Reports fixed memory figures regardless of memInfoType: 512 MiB free out of 1 GiB total. | |||||
| // Always returns RT_ERROR_NONE. Both output pointers are written unconditionally. | |||||
| rtError_t rtMemGetInfoEx(rtMemInfoType_t memInfoType, size_t *free, size_t *total) { | |||||
| const size_t one_mib = 1024UL * 1024UL; | |||||
| // Write order (free, then total) is kept as in the original in case both pointers alias. | |||||
| *free = 512UL * one_mib; | |||||
| *total = 1024UL * one_mib; | |||||
| return RT_ERROR_NONE; | |||||
| } | |||||
| rtError_t rtMemAllocManaged(void **ptr, uint64_t size, uint32_t flag) { return RT_ERROR_NONE; } | rtError_t rtMemAllocManaged(void **ptr, uint64_t size, uint32_t flag) { return RT_ERROR_NONE; } | ||||
| rtError_t rtMemFreeManaged(void *ptr) { return RT_ERROR_NONE; } | rtError_t rtMemFreeManaged(void *ptr) { return RT_ERROR_NONE; } | ||||
| @@ -692,6 +692,7 @@ set(MULTI_PARTS_TEST_FILES | |||||
| "graph/manager/run_graph_unittest.cc" | "graph/manager/run_graph_unittest.cc" | ||||
| "graph/partition/dynamic_shape_partition_unittest.cc" | "graph/partition/dynamic_shape_partition_unittest.cc" | ||||
| "graph/manager/graph_manager_unittest.cc" | "graph/manager/graph_manager_unittest.cc" | ||||
| "graph/manager/graph_var_manager_unittest.cc" | |||||
| "graph/optimize/mem_rw_conflict_optimize_unittest.cc" | "graph/optimize/mem_rw_conflict_optimize_unittest.cc" | ||||
| "graph/optimize/graph_optimize_unittest.cc" | "graph/optimize/graph_optimize_unittest.cc" | ||||
| "session/omg_omg_unittest.cc" | "session/omg_omg_unittest.cc" | ||||
| @@ -79,13 +79,12 @@ public: | |||||
| additionaldata1->add_context_id(5); | additionaldata1->add_context_id(5); | ||||
| } | } | ||||
| void InitAicAivCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { | |||||
| domi::FftsPlusAicAivCtxDef *ctxdef = fftsplusctxdef->mutable_aic_aiv_ctx(); | |||||
| void InitAicAivCtx(domi::FftsPlusAicAivCtxDef *ctxdef) { | |||||
| ctxdef->set_successor_num(26); | ctxdef->set_successor_num(26); | ||||
| ctxdef->set_aten(1); | ctxdef->set_aten(1); | ||||
| ctxdef->set_pred_cnt_init(1); | ctxdef->set_pred_cnt_init(1); | ||||
| ctxdef->set_pred_cnt(1); | ctxdef->set_pred_cnt(1); | ||||
| for (int i = 0; i < RT_CTX_SUCCESSOR_NUM; ++i) { | |||||
| for (int i = 1; i < RT_CTX_SUCCESSOR_NUM; ++i) { | |||||
| ctxdef->add_successor_list(1); // 16 bits, len = 26 | ctxdef->add_successor_list(1); // 16 bits, len = 26 | ||||
| } | } | ||||
| ctxdef->set_stat(1); | ctxdef->set_stat(1); | ||||
| @@ -113,8 +112,7 @@ public: | |||||
| } | } | ||||
| } | } | ||||
| void InitMixAicAivCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { | |||||
| domi::FftsPlusMixAicAivCtxDef *ctxdef = fftsplusctxdef->mutable_mix_aic_aiv_ctx(); | |||||
| void InitMixAicAivCtx(domi::FftsPlusMixAicAivCtxDef *ctxdef) { | |||||
| ctxdef->set_successor_num(26); | ctxdef->set_successor_num(26); | ||||
| ctxdef->set_aten(1); | ctxdef->set_aten(1); | ||||
| ctxdef->set_pred_cnt_init(1); | ctxdef->set_pred_cnt_init(1); | ||||
| @@ -153,8 +151,7 @@ public: | |||||
| } | } | ||||
| } | } | ||||
| void InitSdmaCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { | |||||
| domi::FftsPlusSdmaCtxDef *ctxdef = fftsplusctxdef->mutable_sdma_ctx(); | |||||
| void InitSdmaCtx(domi::FftsPlusSdmaCtxDef *ctxdef) { | |||||
| ctxdef->set_successor_num(26); | ctxdef->set_successor_num(26); | ||||
| ctxdef->set_aten(1); | ctxdef->set_aten(1); | ||||
| ctxdef->set_pred_cnt_init(1); | ctxdef->set_pred_cnt_init(1); | ||||
| @@ -184,8 +181,7 @@ public: | |||||
| ctxdef->set_tail_data_len(1); | ctxdef->set_tail_data_len(1); | ||||
| } | } | ||||
| void InitNotifyCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { | |||||
| domi::FftsPlusNotifyCtxDef *ctxdef = fftsplusctxdef->mutable_notify_ctx(); | |||||
| void InitNotifyCtx(domi::FftsPlusNotifyCtxDef *ctxdef) { | |||||
| ctxdef->set_successor_num(26); | ctxdef->set_successor_num(26); | ||||
| ctxdef->set_aten(1); | ctxdef->set_aten(1); | ||||
| ctxdef->set_pred_cnt_init(1); | ctxdef->set_pred_cnt_init(1); | ||||
| @@ -201,8 +197,7 @@ public: | |||||
| ctxdef->set_notify_id_base(1); | ctxdef->set_notify_id_base(1); | ||||
| } | } | ||||
| void InitWriteValueCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { | |||||
| domi::FftsPlusWriteValueCtxDef *ctxdef = fftsplusctxdef->mutable_write_value_ctx(); | |||||
| void InitWriteValueCtx(domi::FftsPlusWriteValueCtxDef *ctxdef) { | |||||
| ctxdef->set_successor_num(26); | ctxdef->set_successor_num(26); | ||||
| ctxdef->set_aten(1); | ctxdef->set_aten(1); | ||||
| ctxdef->set_pred_cnt_init(1); | ctxdef->set_pred_cnt_init(1); | ||||
| @@ -227,8 +222,7 @@ public: | |||||
| ctxdef->add_write_value(1); | ctxdef->add_write_value(1); | ||||
| } | } | ||||
| void InitAicpuCtxCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { | |||||
| domi::FftsPlusAicpuCtxDef *ctxdef = fftsplusctxdef->mutable_aicpu_ctx(); | |||||
| void InitAicpuCtxCtx(domi::FftsPlusAicpuCtxDef *ctxdef) { | |||||
| ctxdef->set_successor_num(26); | ctxdef->set_successor_num(26); | ||||
| ctxdef->set_aten(1); | ctxdef->set_aten(1); | ||||
| ctxdef->set_pred_cnt_init(1); | ctxdef->set_pred_cnt_init(1); | ||||
| @@ -260,8 +254,7 @@ public: | |||||
| ctxdef->set_task_param_offset(32); | ctxdef->set_task_param_offset(32); | ||||
| } | } | ||||
| void InitDataCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { | |||||
| domi::FftsPlusDataCtxDef *ctxdef = fftsplusctxdef->mutable_data_ctx(); | |||||
| void InitDataCtx(domi::FftsPlusDataCtxDef *ctxdef) { | |||||
| ctxdef->set_successor_num(26); | ctxdef->set_successor_num(26); | ||||
| ctxdef->set_aten(1); | ctxdef->set_aten(1); | ||||
| ctxdef->set_cnt_init(1); | ctxdef->set_cnt_init(1); | ||||
| @@ -293,8 +286,7 @@ public: | |||||
| ctxdef->set_tail_stride_inner(1); | ctxdef->set_tail_stride_inner(1); | ||||
| } | } | ||||
| void InitAtStartCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { | |||||
| domi::FftsPlusAtStartCtxDef *ctxdef = fftsplusctxdef->mutable_at_start_ctx(); | |||||
| void InitAtStartCtx(domi::FftsPlusAtStartCtxDef *ctxdef) { | |||||
| ctxdef->set_successor_num(26); | ctxdef->set_successor_num(26); | ||||
| ctxdef->set_aten(1); | ctxdef->set_aten(1); | ||||
| ctxdef->set_pred_cnt_init(1); | ctxdef->set_pred_cnt_init(1); | ||||
| @@ -309,8 +301,7 @@ public: | |||||
| ctxdef->set_thread_window_size(1); | ctxdef->set_thread_window_size(1); | ||||
| } | } | ||||
| void InitAtEndCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { | |||||
| domi::FftsPlusAtEndCtxDef *ctxdef = fftsplusctxdef->mutable_at_end_ctx(); | |||||
| void InitAtEndCtx(domi::FftsPlusAtEndCtxDef *ctxdef) { | |||||
| ctxdef->set_at_start_slot_num(12); | ctxdef->set_at_start_slot_num(12); | ||||
| ctxdef->set_out_label_slot_num(12); | ctxdef->set_out_label_slot_num(12); | ||||
| ctxdef->set_aten(1); | ctxdef->set_aten(1); | ||||
| @@ -325,8 +316,7 @@ public: | |||||
| ctxdef->set_thread_id(1); | ctxdef->set_thread_id(1); | ||||
| } | } | ||||
| void InitLabelCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { | |||||
| domi::FftsPlusLabelCtxDef *ctxdef = fftsplusctxdef->mutable_label_ctx(); | |||||
| void InitLabelCtx(domi::FftsPlusLabelCtxDef *ctxdef) { | |||||
| ctxdef->set_successor_num(26); | ctxdef->set_successor_num(26); | ||||
| ctxdef->set_pred_cnt_init(1); | ctxdef->set_pred_cnt_init(1); | ||||
| ctxdef->set_pred_cnt(1); | ctxdef->set_pred_cnt(1); | ||||
| @@ -335,8 +325,7 @@ public: | |||||
| } | } | ||||
| } | } | ||||
| void InitCaseSwitchCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { | |||||
| domi::FftsPlusCaseSwitchCtxDef *ctxdef = fftsplusctxdef->mutable_case_switch_ctx(); | |||||
| void InitCaseSwitchCtx(domi::FftsPlusCaseSwitchCtxDef *ctxdef) { | |||||
| ctxdef->set_successor_num(26); | ctxdef->set_successor_num(26); | ||||
| ctxdef->set_aten(32); | ctxdef->set_aten(32); | ||||
| ctxdef->set_start_label_id(32); | ctxdef->set_start_label_id(32); | ||||
| @@ -366,8 +355,7 @@ public: | |||||
| ctxdef->set_load_addr1_offset(32); | ctxdef->set_load_addr1_offset(32); | ||||
| } | } | ||||
| void InitCaseDefaultCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { | |||||
| domi::FftsPlusCaseDefaultCtxDef *ctxdef = fftsplusctxdef->mutable_case_default_ctx(); | |||||
| void InitCaseDefaultCtx(domi::FftsPlusCaseDefaultCtxDef *ctxdef) { | |||||
| ctxdef->set_successor_num(26); | ctxdef->set_successor_num(26); | ||||
| ctxdef->set_aten(32); | ctxdef->set_aten(32); | ||||
| ctxdef->set_start_label_id(1); | ctxdef->set_start_label_id(1); | ||||
| @@ -379,8 +367,7 @@ public: | |||||
| } | } | ||||
| } | } | ||||
| void InitCondSwitchCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { | |||||
| domi::FftsPlusCondSwitchCtxDef *ctxdef = fftsplusctxdef->mutable_cond_switch_ctx(); | |||||
| void InitCondSwitchCtx(domi::FftsPlusCondSwitchCtxDef *ctxdef) { | |||||
| ctxdef->set_true_successor_num(12); | ctxdef->set_true_successor_num(12); | ||||
| ctxdef->set_false_successor_num(14); | ctxdef->set_false_successor_num(14); | ||||
| ctxdef->set_aten(32); | ctxdef->set_aten(32); | ||||
| @@ -444,35 +431,38 @@ TEST_F(UtestFftsPlusTaskInfo, success_ffts_plus_task_info_software_ctx) { | |||||
| InitTaskSQEInfo(ffts_plus_task_def); | InitTaskSQEInfo(ffts_plus_task_def); | ||||
| InitTaskAdditionalDataInfo(ffts_plus_task_def); | InitTaskAdditionalDataInfo(ffts_plus_task_def); | ||||
| domi::FftsPlusCtxDef *startctx = ffts_plus_task_def->add_ffts_plus_ctx(); | |||||
| startctx->set_op_index(0); | |||||
| startctx->set_hardware_ctx_type(0); | |||||
| startctx->set_software_ctx_type(static_cast<uint32_t>(RT_SOFT_CTX_TYPE_AT_START)); | |||||
| InitAtStartCtx(startctx); | |||||
| domi::FftsPlusCtxDef *fftsplusstartctx = ffts_plus_task_def->add_ffts_plus_ctx(); | |||||
| fftsplusstartctx->set_op_index(0); | |||||
| fftsplusstartctx->set_hardware_ctx_type(0); | |||||
| fftsplusstartctx->set_software_ctx_type(static_cast<uint32_t>(RT_SOFT_CTX_TYPE_AT_START)); | |||||
| domi::FftsPlusAtStartCtxDef *startctxdef = fftsplusstartctx->mutable_at_start_ctx(); | |||||
| InitAtStartCtx(startctxdef); | |||||
| EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), FAILED); | EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), FAILED); | ||||
| startctx->at_start_ctx().add_successor_list(1); | |||||
| startctxdef->add_successor_list(1); | |||||
| EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), SUCCESS); | EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), SUCCESS); | ||||
| domi::FftsPlusCtxDef *endctx = ffts_plus_task_def->add_ffts_plus_ctx(); | |||||
| endctx->set_op_index(0); | |||||
| endctx->set_hardware_ctx_type(0); | |||||
| endctx->set_software_ctx_type(static_cast<uint32_t>(RT_SOFT_CTX_TYPE_AT_END)); | |||||
| InitAtEndCtx(endctx); | |||||
| domi::FftsPlusCtxDef *fftsplusendctx = ffts_plus_task_def->add_ffts_plus_ctx(); | |||||
| fftsplusendctx->set_op_index(0); | |||||
| fftsplusendctx->set_hardware_ctx_type(0); | |||||
| fftsplusendctx->set_software_ctx_type(static_cast<uint32_t>(RT_SOFT_CTX_TYPE_AT_END)); | |||||
| domi::FftsPlusAtEndCtxDef *endctxdef = fftsplusendctx->mutable_at_end_ctx(); | |||||
| InitAtEndCtx(endctxdef); | |||||
| EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), FAILED); | EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), FAILED); | ||||
| endctx->at_end_ctx().add_succ_at_start_slot(1); | |||||
| endctxdef->add_succ_at_start_slot(1); | |||||
| EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), FAILED); | EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), FAILED); | ||||
| endctx->at_end_ctx().add_succ_out_label_slot(1); | |||||
| endctxdef->add_succ_out_label_slot(1); | |||||
| EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), SUCCESS); | EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), SUCCESS); | ||||
| domi::FftsPlusCtxDef *labelctx = ffts_plus_task_def->add_ffts_plus_ctx(); | |||||
| labelctx->set_op_index(0); | |||||
| labelctx->set_hardware_ctx_type(0); | |||||
| labelctx->set_software_ctx_type(static_cast<uint32_t>(RT_SOFT_CTX_TYPE_LABEL)); | |||||
| InitLabelCtx(labelctx); | |||||
| // Append a software LABEL context and fill in its label payload. | |||||
| domi::FftsPlusCtxDef *fftspluslabelctx = ffts_plus_task_def->add_ffts_plus_ctx(); | |||||
| fftspluslabelctx->set_op_index(0); | |||||
| fftspluslabelctx->set_hardware_ctx_type(0); | |||||
| fftspluslabelctx->set_software_ctx_type(static_cast<uint32_t>(RT_SOFT_CTX_TYPE_LABEL)); | |||||
| // The payload must be taken from the context created above (not from another ctx def), | |||||
| // so that the label data belongs to the entry whose software type is LABEL. | |||||
| domi::FftsPlusLabelCtxDef *labelctxdef = fftspluslabelctx->mutable_label_ctx(); | |||||
| InitLabelCtx(labelctxdef); | |||||
| EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), FAILED); | EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), FAILED); | ||||
| labelctx->label_ctx().add_successor_list(1); | |||||
| labelctxdef->add_successor_list(1); | |||||
| EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), SUCCESS); | EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), SUCCESS); | ||||
| } | } | ||||
| @@ -501,102 +491,111 @@ TEST_F(UtestFftsPlusTaskInfo, success_ffts_plus_task_info_hardware_ctx) { | |||||
| aicaivctx->set_op_index(0); | aicaivctx->set_op_index(0); | ||||
| aicaivctx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_AIV)); | aicaivctx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_AIV)); | ||||
| aicaivctx->set_software_ctx_type(0); | aicaivctx->set_software_ctx_type(0); | ||||
| InitAicAivCtx(aicaivctx); | |||||
| domi::FftsPlusAicAivCtxDef *aicaivdef = aicaivctx->mutable_aic_aiv_ctx(); | |||||
| InitAicAivCtx(aicaivdef); | |||||
| EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | ||||
| aicaivctx->aic_aiv_ctx().add_successor_list(1); | |||||
| aicaivdef->add_successor_list(1); | |||||
| EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | ||||
| aicaivctx->aic_aiv_ctx().add_kernel_name("aivtest"); | |||||
| aicaivdef->add_kernel_name("aivtest"); | |||||
| EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | ||||
| aicaivctx->aic_aiv_ctx().add_src_slot(1); | |||||
| aicaivdef->add_src_slot(1); | |||||
| EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); | EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); | ||||
| domi::FftsPlusCtxDef *mixaicaivctx = ffts_plus_task_def->add_ffts_plus_ctx(); | domi::FftsPlusCtxDef *mixaicaivctx = ffts_plus_task_def->add_ffts_plus_ctx(); | ||||
| mixaicaivctx->set_op_index(0); | mixaicaivctx->set_op_index(0); | ||||
| mixaicaivctx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_MIX_AIC)); | mixaicaivctx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_MIX_AIC)); | ||||
| mixaicaivctx->set_software_ctx_type(0); | mixaicaivctx->set_software_ctx_type(0); | ||||
| InitMixAicAivCtx(mixaicaivctx); | |||||
| domi::FftsPlusMixAicAivCtxDef *mixctxdef = mixaicaivctx->mutable_mix_aic_aiv_ctx(); | |||||
| InitMixAicAivCtx(mixctxdef); | |||||
| EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | ||||
| mixaicaivctx->mix_aic_aiv_ctx().add_successor_list(1); | |||||
| mixctxdef->add_successor_list(1); | |||||
| EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | ||||
| mixaicaivctx->mix_aic_aiv_ctx().add_kernel_name("mixaiv"); | |||||
| mixctxdef->add_kernel_name("mixaiv"); | |||||
| EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | ||||
| mixaicaivctx->mix_aic_aiv_ctx().add_src_slot(1); | |||||
| mixctxdef->add_src_slot(1); | |||||
| EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); | EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); | ||||
| domi::FftsPlusCtxDef *notifyctx = ffts_plus_task_def->add_ffts_plus_ctx(); | domi::FftsPlusCtxDef *notifyctx = ffts_plus_task_def->add_ffts_plus_ctx(); | ||||
| notifyctx->set_op_index(0); | notifyctx->set_op_index(0); | ||||
| notifyctx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_NOTIFY_WAIT)); | notifyctx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_NOTIFY_WAIT)); | ||||
| notifyctx->set_software_ctx_type(0); | notifyctx->set_software_ctx_type(0); | ||||
| InitNotifyCtx(notifyctx); | |||||
| domi::FftsPlusNotifyCtxDef *notifydef = notifyctx->mutable_notify_ctx(); | |||||
| InitNotifyCtx(notifydef); | |||||
| EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | ||||
| notifyctx->notify_ctx().add_successor_list(1); | |||||
| notifydef->add_successor_list(1); | |||||
| EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); | EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); | ||||
| domi::FftsPlusCtxDef *sdmactx = ffts_plus_task_def->add_ffts_plus_ctx(); | domi::FftsPlusCtxDef *sdmactx = ffts_plus_task_def->add_ffts_plus_ctx(); | ||||
| sdmactx->set_op_index(0); | sdmactx->set_op_index(0); | ||||
| sdmactx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_SDMA)); | sdmactx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_SDMA)); | ||||
| sdmactx->set_software_ctx_type(0); | sdmactx->set_software_ctx_type(0); | ||||
| InitSdmaCtx(sdmactx); | |||||
| domi::FftsPlusSdmaCtxDef *smdadef = sdmactx->mutable_sdma_ctx(); | |||||
| InitSdmaCtx(smdadef); | |||||
| EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | ||||
| sdmactx->sdma_ctx().add_successor_list(1); | |||||
| smdadef->add_successor_list(1); | |||||
| EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); | EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); | ||||
| domi::FftsPlusCtxDef *writevalctx = ffts_plus_task_def->add_ffts_plus_ctx(); | domi::FftsPlusCtxDef *writevalctx = ffts_plus_task_def->add_ffts_plus_ctx(); | ||||
| writevalctx->set_op_index(0); | writevalctx->set_op_index(0); | ||||
| writevalctx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_WRITE_VALUE)); | writevalctx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_WRITE_VALUE)); | ||||
| writevalctx->set_software_ctx_type(0); | writevalctx->set_software_ctx_type(0); | ||||
| InitWriteValueCtx(writevalctx); | |||||
| domi::FftsPlusWriteValueCtxDef *writedef = writevalctx->mutable_write_value_ctx(); | |||||
| InitWriteValueCtx(writedef); | |||||
| EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | ||||
| writevalctx->write_value_ctx().add_successor_list(1); | |||||
| writedef->add_successor_list(1); | |||||
| EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); | EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); | ||||
| domi::FftsPlusCtxDef *aicpuctx = ffts_plus_task_def->add_ffts_plus_ctx(); | domi::FftsPlusCtxDef *aicpuctx = ffts_plus_task_def->add_ffts_plus_ctx(); | ||||
| aicpuctx->set_op_index(0); | aicpuctx->set_op_index(0); | ||||
| aicpuctx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_AICPU)); | aicpuctx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_AICPU)); | ||||
| aicpuctx->set_software_ctx_type(0); | aicpuctx->set_software_ctx_type(0); | ||||
| InitAicpuCtxCtx(aicpuctx); | |||||
| domi::FftsPlusAicpuCtxDef *aicpudef = aicpuctx->mutable_aicpu_ctx(); | |||||
| InitAicpuCtxCtx(aicpudef); | |||||
| EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | ||||
| aicpuctx->aicpu_ctx().add_successor_context_id(1); | |||||
| aicpudef->add_successor_context_id(1); | |||||
| EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | ||||
| aicpuctx->aicpu_ctx().add_user_data(1); | |||||
| aicpudef->add_user_data(1); | |||||
| EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); | EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); | ||||
| domi::FftsPlusCtxDef *datactx = ffts_plus_task_def->add_ffts_plus_ctx(); | domi::FftsPlusCtxDef *datactx = ffts_plus_task_def->add_ffts_plus_ctx(); | ||||
| datactx->set_op_index(0); | datactx->set_op_index(0); | ||||
| datactx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_FLUSH_DATA)); | datactx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_FLUSH_DATA)); | ||||
| datactx->set_software_ctx_type(0); | datactx->set_software_ctx_type(0); | ||||
| InitDataCtx(datactx); | |||||
| domi::FftsPlusDataCtxDef *datadef = datactx->mutable_data_ctx(); | |||||
| InitDataCtx(datadef); | |||||
| EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | ||||
| datactx->data_ctx().add_successor_list(1); | |||||
| datadef->add_successor_list(1); | |||||
| EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); | EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); | ||||
| domi::FftsPlusCtxDef *caseswitchctx = ffts_plus_task_def->add_ffts_plus_ctx(); | domi::FftsPlusCtxDef *caseswitchctx = ffts_plus_task_def->add_ffts_plus_ctx(); | ||||
| caseswitchctx->set_op_index(0); | caseswitchctx->set_op_index(0); | ||||
| caseswitchctx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_LOAD)); | caseswitchctx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_LOAD)); | ||||
| caseswitchctx->set_software_ctx_type(static_cast<uint32_t>(RT_SOFT_CTX_TYPE_CASE_SWITCH)); | caseswitchctx->set_software_ctx_type(static_cast<uint32_t>(RT_SOFT_CTX_TYPE_CASE_SWITCH)); | ||||
| InitCaseSwitchCtx(caseswitchctx); | |||||
| domi::FftsPlusCaseSwitchCtxDef *caseswitchdef = caseswitchctx->mutable_case_switch_ctx(); | |||||
| InitCaseSwitchCtx(caseswitchdef); | |||||
| EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | ||||
| caseswitchctx->case_switch_ctx().add_successor_list(1); | |||||
| caseswitchdef->add_successor_list(1); | |||||
| EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); | EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); | ||||
| domi::FftsPlusCtxDef *candswitchctx = ffts_plus_task_def->add_ffts_plus_ctx(); | domi::FftsPlusCtxDef *candswitchctx = ffts_plus_task_def->add_ffts_plus_ctx(); | ||||
| candswitchctx->set_op_index(0); | candswitchctx->set_op_index(0); | ||||
| candswitchctx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_LOAD)); | candswitchctx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_LOAD)); | ||||
| candswitchctx->set_software_ctx_type(static_cast<uint32_t>(RT_SOFT_CTX_TYPE_COND_SWITCH)); | candswitchctx->set_software_ctx_type(static_cast<uint32_t>(RT_SOFT_CTX_TYPE_COND_SWITCH)); | ||||
| InitCondSwitchCtx(candswitchctx); | |||||
| domi::FftsPlusCondSwitchCtxDef *candswitchdef = candswitchctx->mutable_cond_switch_ctx(); | |||||
| InitCondSwitchCtx(candswitchdef); | |||||
| EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | ||||
| candswitchctx->cond_switch_ctx().add_true_successor_list(1); | |||||
| candswitchdef->add_true_successor_list(1); | |||||
| EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | ||||
| candswitchctx->cond_switch_ctx().add_false_successor_list(1); | |||||
| candswitchdef->add_false_successor_list(1); | |||||
| EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); | EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); | ||||
| } | } | ||||
| @@ -625,10 +624,11 @@ TEST_F(UtestFftsPlusTaskInfo, success_ffts_plus_task_info_hardware_ctx_ex) { | |||||
| casesdefaultctx->set_op_index(0); | casesdefaultctx->set_op_index(0); | ||||
| casesdefaultctx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_LOAD)); | casesdefaultctx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_LOAD)); | ||||
| casesdefaultctx->set_software_ctx_type(static_cast<uint32_t>(RT_SOFT_CTX_TYPE_CASE_SWITCH)); | casesdefaultctx->set_software_ctx_type(static_cast<uint32_t>(RT_SOFT_CTX_TYPE_CASE_SWITCH)); | ||||
| InitCaseDefaultCtx(casesdefaultctx); | |||||
| domi::FftsPlusCaseDefaultCtxDef *casesdefaultdef = casesdefaultctx->mutable_case_default_ctx(); | |||||
| InitCaseDefaultCtx(casesdefaultdef); | |||||
| EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | ||||
| casesdefaultctx->case_default_ctx().add_successor_list(1); | |||||
| casesdefaultdef->add_successor_list(1); | |||||
| EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); | EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); | ||||
| } | } | ||||
| // test FftsPlusTaskInfo UpdateArgs | // test FftsPlusTaskInfo UpdateArgs | ||||
| @@ -0,0 +1,63 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <gtest/gtest.h> | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #define protected public | |||||
| #define private public | |||||
| #include "graph/manager/graph_var_manager.h" | |||||
| #include "graph/ge_context.h" | |||||
| #undef protected | |||||
| #undef private | |||||
| namespace ge { | |||||
| // Fixture for the VarManager memory-size tests; it performs no per-test setup or teardown. | |||||
| class UtestGraphVarManagerTest : public testing::Test { | |||||
| protected: | |||||
| void SetUp() override {} | |||||
| void TearDown() override {} | |||||
| }; | |||||
| // GetTotalMemorySize is expected to succeed and report a total of 1 GiB. | |||||
| TEST_F(UtestGraphVarManagerTest, test_get_total_memory_size) { | |||||
| const size_t expected_total = 1024UL * 1024UL * 1024UL; | |||||
| size_t reported_total = 0; | |||||
| const Status status = VarManager::Instance(0)->GetTotalMemorySize(reported_total); | |||||
| EXPECT_EQ(reported_total, expected_total); | |||||
| EXPECT_EQ(status, SUCCESS); | |||||
| } | |||||
| // With no size-related options, SetMemoryMallocSize is expected to succeed and leave the | |||||
| // graph limit at 26/32 and the variable limit at 5/32 of a 1 GiB total. | |||||
| // NOTE(review): the 1 GiB base is presumably the total reported by the runtime stub - confirm. | |||||
| TEST_F(UtestGraphVarManagerTest, test_set_memory_malloc_size_no_related_option) { | |||||
| // Expectations are computed in integer arithmetic: the previous float-typed expressions made | |||||
| // EXPECT_EQ compare after conversion to float, which can hide differences in the low bits. | |||||
| constexpr size_t kTotalMemSize = 1024UL * 1024UL * 1024UL; | |||||
| constexpr size_t kExpectedGraphMemSize = kTotalMemSize / 32UL * 26UL; | |||||
| constexpr size_t kExpectedVarMemSize = kTotalMemSize / 32UL * 5UL; | |||||
| const std::map<std::string, std::string> options{}; | |||||
| const Status ret = VarManager::Instance(0)->SetMemoryMallocSize(options); | |||||
| EXPECT_EQ(ret, SUCCESS); | |||||
| EXPECT_EQ(VarManager::Instance(0)->graph_mem_max_size_, kExpectedGraphMemSize); | |||||
| EXPECT_EQ(VarManager::Instance(0)->var_mem_max_size_, kExpectedVarMemSize); | |||||
| } | |||||
| // When "ge.graphMemoryMaxSize" is given, the graph limit is expected to take the user value | |||||
| // (536870912 bytes, i.e. half of 1 GiB) while the variable limit stays at 5/32 of 1 GiB. | |||||
| TEST_F(UtestGraphVarManagerTest, test_set_memory_malloc_size_with_user_specify_graph_mem_max_size) { | |||||
| // Integer expectations keep the comparison exact; the former floor()/float expressions | |||||
| // forced a floating-point comparison against the integer members. | |||||
| constexpr size_t kTotalMemSize = 1024UL * 1024UL * 1024UL; | |||||
| constexpr size_t kExpectedGraphMemSize = kTotalMemSize / 2UL; | |||||
| constexpr size_t kExpectedVarMemSize = kTotalMemSize / 32UL * 5UL; | |||||
| const std::map<std::string, std::string> options{{"ge.graphMemoryMaxSize", "536870912"}}; | |||||
| const Status ret = VarManager::Instance(0)->SetMemoryMallocSize(options); | |||||
| EXPECT_EQ(ret, SUCCESS); | |||||
| EXPECT_EQ(VarManager::Instance(0)->graph_mem_max_size_, kExpectedGraphMemSize); | |||||
| EXPECT_EQ(VarManager::Instance(0)->var_mem_max_size_, kExpectedVarMemSize); | |||||
| } | |||||
| // When "ge.variableMemoryMaxSize" is given, the variable limit is expected to take the user | |||||
| // value (536870912 bytes, i.e. half of 1 GiB) while the graph limit stays at 26/32 of 1 GiB. | |||||
| TEST_F(UtestGraphVarManagerTest, test_set_memory_malloc_size_with_user_specify_var_mem_max_size) { | |||||
| // Integer expectations keep the comparison exact; the former floor()/float expressions | |||||
| // forced a floating-point comparison against the integer members. | |||||
| constexpr size_t kTotalMemSize = 1024UL * 1024UL * 1024UL; | |||||
| constexpr size_t kExpectedGraphMemSize = kTotalMemSize / 32UL * 26UL; | |||||
| constexpr size_t kExpectedVarMemSize = kTotalMemSize / 2UL; | |||||
| const std::map<std::string, std::string> options{{"ge.variableMemoryMaxSize", "536870912"}}; | |||||
| const Status ret = VarManager::Instance(0)->SetMemoryMallocSize(options); | |||||
| EXPECT_EQ(ret, SUCCESS); | |||||
| EXPECT_EQ(VarManager::Instance(0)->graph_mem_max_size_, kExpectedGraphMemSize); | |||||
| EXPECT_EQ(VarManager::Instance(0)->var_mem_max_size_, kExpectedVarMemSize); | |||||
| } | |||||
| } // namespace ge | |||||