From: @shenwei41 Reviewed-by: @xsmq,@lilongfei15 Signed-off-by: @lilongfei15tags/v1.2.0
| @@ -229,7 +229,7 @@ if [[ "X$ENABLE_GE_UT" = "Xon" || "X$ENABLE_GE_COV" = "Xon" ]]; then | |||||
| rm -rf ${BASEPATH}/cov | rm -rf ${BASEPATH}/cov | ||||
| mkdir ${BASEPATH}/cov | mkdir ${BASEPATH}/cov | ||||
| lcov -c -d build/tests/ut/ge -d build/tests/ut/common/graph/ -o cov/tmp.info | lcov -c -d build/tests/ut/ge -d build/tests/ut/common/graph/ -o cov/tmp.info | ||||
| lcov -r cov/tmp.info '*/output/*' '*/build/opensrc/*' '*/build/proto/*' '*/third_party/*' '*/tests/*' '/usr/local/*' -o cov/coverage.info | |||||
| lcov -r cov/tmp.info '*/output/*' '*/build/opensrc/*' '*/build/proto/*' '*/third_party/*' '*/tests/*' '/usr/local/*' '/usr/include/*' '*/metadef/*' '*/parser/*' -o cov/coverage.info | |||||
| cd ${BASEPATH}/cov | cd ${BASEPATH}/cov | ||||
| genhtml coverage.info | genhtml coverage.info | ||||
| fi | fi | ||||
| @@ -31,6 +31,7 @@ set(PROTO_HEADER_LIST | |||||
| protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | ||||
| protobuf_generate(ge PROTO_CLIENT_SRCS PROTO_CLIENT_HDRS ${PROTO_CLIENT_LIST}) | protobuf_generate(ge PROTO_CLIENT_SRCS PROTO_CLIENT_HDRS ${PROTO_CLIENT_LIST}) | ||||
| protobuf_generate(ge PROTO_HEADER_SRCS PROTO_HEADER_HDRS ${PROTO_HEADER_LIST}) | protobuf_generate(ge PROTO_HEADER_SRCS PROTO_HEADER_HDRS ${PROTO_HEADER_LIST}) | ||||
| protobuf_generate(ge_client PROTO_CLIENT_HEADER_SRCS PROTO_CLIENT_HEADER_HDRS ${PROTO_HEADER_LIST}) | |||||
| if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES) | if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES) | ||||
| ############ libge_proto_common.a ############ | ############ libge_proto_common.a ############ | ||||
| @@ -56,7 +57,7 @@ target_link_libraries(ge_proto_common PRIVATE | |||||
| ############ libge_proto_client.a ############ | ############ libge_proto_client.a ############ | ||||
| add_library(ge_proto_client STATIC | add_library(ge_proto_client STATIC | ||||
| ${PROTO_HEADER_HDRS} | |||||
| ${PROTO_CLIENT_HEADER_HDRS} | |||||
| ${PROTO_CLIENT_SRCS} | ${PROTO_CLIENT_SRCS} | ||||
| ) | ) | ||||
| @@ -65,6 +66,11 @@ target_compile_definitions(ge_proto_client PRIVATE | |||||
| google=ascend_private | google=ascend_private | ||||
| ) | ) | ||||
| target_include_directories(ge_proto_client PRIVATE | |||||
| ${CMAKE_BINARY_DIR}/proto/ge_client | |||||
| ${CMAKE_BINARY_DIR}/proto/ge_client/proto | |||||
| ) | |||||
| target_compile_options(ge_proto_client PRIVATE | target_compile_options(ge_proto_client PRIVATE | ||||
| -O2 | -O2 | ||||
| -fno-common | -fno-common | ||||
| @@ -16,6 +16,7 @@ set(PROTO_LIST | |||||
| ) | ) | ||||
| protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | ||||
| protobuf_generate(ge_static PROTO_STATIC_SRCS PROTO_STATIC_HDRS ${PROTO_LIST}) | |||||
| set(SRC_LIST | set(SRC_LIST | ||||
| "context/ctx.cc" | "context/ctx.cc" | ||||
| @@ -127,7 +128,7 @@ target_link_libraries(ge_common PRIVATE | |||||
| ) | ) | ||||
| ############ libge_common.a ############ | ############ libge_common.a ############ | ||||
| add_library(ge_common_static STATIC ${SRC_LIST} ${PROTO_HDRS}) | |||||
| add_library(ge_common_static STATIC ${SRC_LIST} ${PROTO_STATIC_HDRS}) | |||||
| target_compile_definitions(ge_common_static PRIVATE | target_compile_definitions(ge_common_static PRIVATE | ||||
| PROTOBUF_INLINE_NOT_IN_HEADERS=0 | PROTOBUF_INLINE_NOT_IN_HEADERS=0 | ||||
| HOST_VISIBILITY | HOST_VISIBILITY | ||||
| @@ -158,7 +159,7 @@ target_include_directories(ge_common_static PRIVATE | |||||
| ${METADEF_DIR}/inc/external/graph | ${METADEF_DIR}/inc/external/graph | ||||
| ${METADEF_DIR}/inc/graph | ${METADEF_DIR}/inc/graph | ||||
| ${CMAKE_BINARY_DIR} | ${CMAKE_BINARY_DIR} | ||||
| ${CMAKE_BINARY_DIR}/proto/ge | |||||
| ${CMAKE_BINARY_DIR}/proto/ge_static | |||||
| #### yellow zone #### | #### yellow zone #### | ||||
| ${GE_DEPEND_DIR}/inc | ${GE_DEPEND_DIR}/inc | ||||
| ${GE_DEPEND_DIR}/inc/cce | ${GE_DEPEND_DIR}/inc/cce | ||||
| @@ -80,13 +80,11 @@ Status OpdebugRegister::RegisterDebugForStream(rtStream_t stream, uint32_t op_de | |||||
| uint32_t debug_stream_id = 0; | uint32_t debug_stream_id = 0; | ||||
| uint32_t debug_task_id = 0; | uint32_t debug_task_id = 0; | ||||
| #ifdef ONLY_COMPILE_OPEN_SRC | |||||
| auto rt_ret = rtDebugRegisterForStream(stream, op_debug_mode, op_debug_addr_, &debug_stream_id, &debug_task_id); | auto rt_ret = rtDebugRegisterForStream(stream, op_debug_mode, op_debug_addr_, &debug_stream_id, &debug_task_id); | ||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| GELOGE(RT_FAILED, "rtDebugRegisterForStream error, ret: 0x%X", rt_ret); | GELOGE(RT_FAILED, "rtDebugRegisterForStream error, ret: 0x%X", rt_ret); | ||||
| return RT_ERROR_TO_GE_STATUS(rt_ret); | return RT_ERROR_TO_GE_STATUS(rt_ret); | ||||
| } | } | ||||
| #endif | |||||
| GELOGD("debug_task_id:%u, debug_stream_id:%u in stream overflow.", debug_task_id, debug_stream_id); | GELOGD("debug_task_id:%u, debug_stream_id:%u in stream overflow.", debug_task_id, debug_stream_id); | ||||
| data_dumper.SaveOpDebugId(debug_task_id, debug_stream_id, p2p_debug_addr_, true); | data_dumper.SaveOpDebugId(debug_task_id, debug_stream_id, p2p_debug_addr_, true); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -94,7 +92,6 @@ Status OpdebugRegister::RegisterDebugForStream(rtStream_t stream, uint32_t op_de | |||||
| void OpdebugRegister::UnregisterDebugForStream(rtStream_t stream) { | void OpdebugRegister::UnregisterDebugForStream(rtStream_t stream) { | ||||
| rtError_t rt_ret = RT_ERROR_NONE; | rtError_t rt_ret = RT_ERROR_NONE; | ||||
| #ifdef ONLY_COMPILE_OPEN_SRC | |||||
| if (stream != nullptr) { | if (stream != nullptr) { | ||||
| GELOGD("start call rtDebugUnRegisterForStream in unknown shape over flow."); | GELOGD("start call rtDebugUnRegisterForStream in unknown shape over flow."); | ||||
| rt_ret = rtDebugUnRegisterForStream(stream); | rt_ret = rtDebugUnRegisterForStream(stream); | ||||
| @@ -102,8 +99,6 @@ void OpdebugRegister::UnregisterDebugForStream(rtStream_t stream) { | |||||
| GELOGW("rtDebugUnRegisterForStream failed, ret: 0x%X", rt_ret); | GELOGW("rtDebugUnRegisterForStream failed, ret: 0x%X", rt_ret); | ||||
| } | } | ||||
| } | } | ||||
| #endif | |||||
| if (op_debug_addr_ != nullptr) { | if (op_debug_addr_ != nullptr) { | ||||
| rt_ret = rtFree(op_debug_addr_); | rt_ret = rtFree(op_debug_addr_); | ||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| @@ -8,6 +8,7 @@ set(PROTO_LIST | |||||
| ) | ) | ||||
| protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | ||||
| protobuf_generate(ge_static PROTO_STATIC_SRCS PROTO_STATIC_HDRS ${PROTO_LIST}) | |||||
| set(SRC_LIST | set(SRC_LIST | ||||
| "ge_executor.cc" | "ge_executor.cc" | ||||
| @@ -162,7 +163,7 @@ set(SRC_LIST | |||||
| ) | ) | ||||
| ######## libge_executor.a ######## | ######## libge_executor.a ######## | ||||
| add_library(ge_executor STATIC ${SRC_LIST} ${PROTO_HDRS}) | |||||
| add_library(ge_executor STATIC ${SRC_LIST} ${PROTO_STATIC_HDRS}) | |||||
| target_compile_options(ge_executor PRIVATE | target_compile_options(ge_executor PRIVATE | ||||
| $<$<OR:$<STREQUAL:${TARGET_SYSTEM_NAME},Linux>,$<STREQUAL:${TARGET_SYSTEM_NAME},Android>>:-fvisibility=hidden -O2 -Werror -Wno-deprecated-declarations -fno-common> | $<$<OR:$<STREQUAL:${TARGET_SYSTEM_NAME},Linux>,$<STREQUAL:${TARGET_SYSTEM_NAME},Android>>:-fvisibility=hidden -O2 -Werror -Wno-deprecated-declarations -fno-common> | ||||
| @@ -191,7 +192,7 @@ target_include_directories(ge_executor SYSTEM PRIVATE | |||||
| ${METADEF_DIR}/inc/external/graph | ${METADEF_DIR}/inc/external/graph | ||||
| ${METADEF_DIR}/inc/graph | ${METADEF_DIR}/inc/graph | ||||
| ${CMAKE_BINARY_DIR} | ${CMAKE_BINARY_DIR} | ||||
| ${CMAKE_BINARY_DIR}/proto/ge | |||||
| ${CMAKE_BINARY_DIR}/proto/ge_static | |||||
| #### yellow zone #### | #### yellow zone #### | ||||
| ${GE_CODE_DIR}/../inc | ${GE_CODE_DIR}/../inc | ||||
| ${GE_CODE_DIR}/../inc/cce | ${GE_CODE_DIR}/../inc/cce | ||||
| @@ -20,6 +20,8 @@ set(OPS_KERNEL_SRC_LIST | |||||
| ) | ) | ||||
| protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | ||||
| protobuf_generate(ge_ops_shared PROTO_OPS_SHARED_SRCS PROTO_OPS_SHARED_HDRS ${PROTO_LIST}) | |||||
| protobuf_generate(ge_ops_static PROTO_OPS_STATIC_SRCS PROTO_OPS_STATIC_HDRS ${PROTO_LIST}) | |||||
| ############ libge_local_engine.so ############ | ############ libge_local_engine.so ############ | ||||
| add_library(ge_local_engine SHARED ${SRC_LIST} ${PROTO_HDRS}) | add_library(ge_local_engine SHARED ${SRC_LIST} ${PROTO_HDRS}) | ||||
| @@ -119,7 +121,7 @@ set_target_properties(atc_ge_local_engine PROPERTIES | |||||
| ) | ) | ||||
| ############ libge_local_opskernel_builder.so ############ | ############ libge_local_opskernel_builder.so ############ | ||||
| add_library(ge_local_opskernel_builder SHARED ${OPS_KERNEL_SRC_LIST} ${PROTO_HDRS}) | |||||
| add_library(ge_local_opskernel_builder SHARED ${OPS_KERNEL_SRC_LIST} ${PROTO_OPS_SHARED_HDRS}) | |||||
| target_compile_options(ge_local_opskernel_builder PRIVATE | target_compile_options(ge_local_opskernel_builder PRIVATE | ||||
| -Werror | -Werror | ||||
| @@ -143,7 +145,7 @@ target_include_directories(ge_local_opskernel_builder PRIVATE | |||||
| ${METADEF_DIR}/inc/external/graph | ${METADEF_DIR}/inc/external/graph | ||||
| ${METADEF_DIR}/inc/graph | ${METADEF_DIR}/inc/graph | ||||
| ${CMAKE_BINARY_DIR} | ${CMAKE_BINARY_DIR} | ||||
| ${CMAKE_BINARY_DIR}/proto/ge | |||||
| ${CMAKE_BINARY_DIR}/proto/ge_ops_shared | |||||
| #### yellow zone #### | #### yellow zone #### | ||||
| ${GE_CODE_DIR}/../inc | ${GE_CODE_DIR}/../inc | ||||
| #### blue zone #### | #### blue zone #### | ||||
| @@ -166,7 +168,7 @@ target_link_libraries(ge_local_opskernel_builder PRIVATE | |||||
| ) | ) | ||||
| ############ atclib/libge_local_opskernel_builder.so ############ | ############ atclib/libge_local_opskernel_builder.so ############ | ||||
| add_library(atc_ge_local_opskernel_builder SHARED ${OPS_KERNEL_SRC_LIST} ${PROTO_HDRS}) | |||||
| add_library(atc_ge_local_opskernel_builder SHARED ${OPS_KERNEL_SRC_LIST} ${PROTO_OPS_SHARED_HDRS}) | |||||
| target_compile_options(atc_ge_local_opskernel_builder PRIVATE | target_compile_options(atc_ge_local_opskernel_builder PRIVATE | ||||
| -Werror | -Werror | ||||
| @@ -190,7 +192,7 @@ target_include_directories(atc_ge_local_opskernel_builder PRIVATE | |||||
| ${METADEF_DIR}/inc/external/graph | ${METADEF_DIR}/inc/external/graph | ||||
| ${METADEF_DIR}/inc/graph | ${METADEF_DIR}/inc/graph | ||||
| ${CMAKE_BINARY_DIR} | ${CMAKE_BINARY_DIR} | ||||
| ${CMAKE_BINARY_DIR}/proto/ge | |||||
| ${CMAKE_BINARY_DIR}/proto/ge_ops_shared | |||||
| #### yellow zone #### | #### yellow zone #### | ||||
| ${GE_CODE_DIR}/../inc | ${GE_CODE_DIR}/../inc | ||||
| #### blue zone #### | #### blue zone #### | ||||
| @@ -218,7 +220,7 @@ set_target_properties(atc_ge_local_opskernel_builder PROPERTIES | |||||
| ) | ) | ||||
| ############ libge_local_opskernel_builder.a ############ | ############ libge_local_opskernel_builder.a ############ | ||||
| add_library(ge_local_opskernel_builder_static STATIC ${OPS_KERNEL_SRC_LIST} ${PROTO_HDRS}) | |||||
| add_library(ge_local_opskernel_builder_static STATIC ${OPS_KERNEL_SRC_LIST} ${PROTO_OPS_STATIC_HDRS}) | |||||
| target_compile_options(ge_local_opskernel_builder_static PRIVATE | target_compile_options(ge_local_opskernel_builder_static PRIVATE | ||||
| -Werror | -Werror | ||||
| @@ -243,7 +245,7 @@ target_include_directories(ge_local_opskernel_builder_static PRIVATE | |||||
| ${METADEF_DIR}/inc/external/graph | ${METADEF_DIR}/inc/external/graph | ||||
| ${METADEF_DIR}/inc/graph | ${METADEF_DIR}/inc/graph | ||||
| ${CMAKE_BINARY_DIR} | ${CMAKE_BINARY_DIR} | ||||
| ${CMAKE_BINARY_DIR}/proto/ge | |||||
| ${CMAKE_BINARY_DIR}/proto/ge_ops_static | |||||
| #### yellow zone #### | #### yellow zone #### | ||||
| ${GE_CODE_DIR}/../inc | ${GE_CODE_DIR}/../inc | ||||
| #### blue zone #### | #### blue zone #### | ||||
| @@ -49,6 +49,7 @@ const char *const kIsLastNode = "is_last_node"; | |||||
| const char *const kIsInputVar = "INPUT_IS_VAR"; | const char *const kIsInputVar = "INPUT_IS_VAR"; | ||||
| const char *const kIsOutputVar = "OUTPUT_IS_VAR"; | const char *const kIsOutputVar = "OUTPUT_IS_VAR"; | ||||
| const char *const kProfilingMode = "PROFILING_MODE"; | const char *const kProfilingMode = "PROFILING_MODE"; | ||||
| const char *const kIteratorV2 = "IteratorV2"; | |||||
| const uint32_t kProfilingArStep = 2; | const uint32_t kProfilingArStep = 2; | ||||
| const uint64_t kProfilingFpStartLogid = 1; | const uint64_t kProfilingFpStartLogid = 1; | ||||
| const uint64_t kProfilingBpEndLogid = 2; | const uint64_t kProfilingBpEndLogid = 2; | ||||
| @@ -57,6 +58,7 @@ const uint64_t kProfilingArEndLogid = 4; | |||||
| const uint64_t kProfilingIterEndLogid = 65535; | const uint64_t kProfilingIterEndLogid = 65535; | ||||
| const int64_t kHashFactor = 100000; | const int64_t kHashFactor = 100000; | ||||
| const int64_t kInvalidGroupId = -1; | const int64_t kInvalidGroupId = -1; | ||||
| const std::set<std::string> kFpNodeTypes = {ge::DATA, ge::GETNEXT, kIteratorV2}; | |||||
| } // namespace | } // namespace | ||||
| namespace ge { | namespace ge { | ||||
| TaskGenerator::TaskGenerator(uint8_t *var_mem_base, uint64_t var_mem_size) { | TaskGenerator::TaskGenerator(uint8_t *var_mem_base, uint64_t var_mem_size) { | ||||
| @@ -621,8 +623,10 @@ Status TaskGenerator::AutoFindFpOpIndex(const ComputeGraphPtr &graph, ProfilingP | |||||
| if (op_kernel_lib_name.empty()) { | if (op_kernel_lib_name.empty()) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (op_desc->GetType() == GETNEXT || op_desc->GetType() == DATA) { | |||||
| auto type = op_desc->GetType(); | |||||
| std::string original_type; | |||||
| (void)AttrUtils::GetStr(op_desc, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type); | |||||
| if (kFpNodeTypes.find(type) != kFpNodeTypes.end() || kFpNodeTypes.find(original_type) != kFpNodeTypes.end()) { | |||||
| auto out_anchor = node->GetOutDataAnchor(0); | auto out_anchor = node->GetOutDataAnchor(0); | ||||
| for (auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) { | for (auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) { | ||||
| GE_CHECK_NOTNULL(peer_in_anchor); | GE_CHECK_NOTNULL(peer_in_anchor); | ||||
| @@ -356,6 +356,14 @@ void CachingAllocator::FreeBlocks() { | |||||
| (void) FreeCachedBlocks(); | (void) FreeCachedBlocks(); | ||||
| } | } | ||||
| void CachingAllocator::TryFreeBlocks() { | |||||
| GELOGI("Try free blocks."); | |||||
| std::lock_guard<std::recursive_mutex> lock(mutex_); | |||||
| if (allocated_blocks_.empty()) { | |||||
| (void) FreeCachedBlocks(); | |||||
| } | |||||
| } | |||||
| void CachingAllocator::FreeBlockBins() { | void CachingAllocator::FreeBlockBins() { | ||||
| GELOGI("Free block bins."); | GELOGI("Free block bins."); | ||||
| std::lock_guard<std::recursive_mutex> lock(mutex_); | std::lock_guard<std::recursive_mutex> lock(mutex_); | ||||
| @@ -94,6 +94,13 @@ class CachingAllocator { | |||||
| /// | /// | ||||
| Status Free(uint8_t *memory_addr, uint32_t device_id = 0); | Status Free(uint8_t *memory_addr, uint32_t device_id = 0); | ||||
| /// | |||||
| /// @ingroup ge_graph | |||||
| /// @brief try to free memory when no memory is referenced | |||||
| /// @return void | |||||
| /// | |||||
| void TryFreeBlocks(); | |||||
| private: | private: | ||||
| /// | /// | ||||
| @@ -126,11 +126,11 @@ bool AtomicAddrCleanPass::IsOutputIndexPeerInputAtomic(const NodePtr &node, int6 | |||||
| bool AtomicAddrCleanPass::CheckSkipInsertInLoopGraph(const NodePtr &node) { | bool AtomicAddrCleanPass::CheckSkipInsertInLoopGraph(const NodePtr &node) { | ||||
| OpDescPtr op_desc = node->GetOpDesc(); | OpDescPtr op_desc = node->GetOpDesc(); | ||||
| std::map<string, std::map<int, int>> node_workspace_offset; | |||||
| std::map<string, std::map<int64_t, int64_t>> atomic_workspace_index_size; | |||||
| bool has_atomic_input = op_desc->HasAttr(ATOMIC_ATTR_INPUT_INDEX); | bool has_atomic_input = op_desc->HasAttr(ATOMIC_ATTR_INPUT_INDEX); | ||||
| bool has_atomic_output = op_desc->HasAttr(ATOMIC_ATTR_OUTPUT_INDEX); | bool has_atomic_output = op_desc->HasAttr(ATOMIC_ATTR_OUTPUT_INDEX); | ||||
| node_workspace_offset = op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_OFFSET, node_workspace_offset); | |||||
| if (!has_atomic_input && has_atomic_output && node_workspace_offset.empty()) { | |||||
| atomic_workspace_index_size = op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_INFO, atomic_workspace_index_size); | |||||
| if (!has_atomic_input && has_atomic_output && atomic_workspace_index_size.empty()) { | |||||
| std::vector<int64_t> atomic_output_index; | std::vector<int64_t> atomic_output_index; | ||||
| (void) ge::AttrUtils::GetListInt(op_desc, ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_index); | (void) ge::AttrUtils::GetListInt(op_desc, ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_index); | ||||
| bool is_all_output_peer_also_atomic = true; | bool is_all_output_peer_also_atomic = true; | ||||
| @@ -332,11 +332,11 @@ bool AtomicAddrCleanPass::IsAtomicOp(const NodePtr &node) { | |||||
| } | } | ||||
| // 2.Check atomic attr in node | // 2.Check atomic attr in node | ||||
| std::map<string, std::map<int, int>> node_workspace_offset; | |||||
| std::map<string, std::map<int64_t, int64_t>> atomic_workspace_index_size; | |||||
| bool has_atomic_input = op_desc->HasAttr(ATOMIC_ATTR_INPUT_INDEX); | bool has_atomic_input = op_desc->HasAttr(ATOMIC_ATTR_INPUT_INDEX); | ||||
| bool has_atomic_output = op_desc->HasAttr(ATOMIC_ATTR_OUTPUT_INDEX); | bool has_atomic_output = op_desc->HasAttr(ATOMIC_ATTR_OUTPUT_INDEX); | ||||
| node_workspace_offset = op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_OFFSET, node_workspace_offset); | |||||
| if (!has_atomic_input && !has_atomic_output && node_workspace_offset.empty()) { | |||||
| atomic_workspace_index_size = op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_INFO, atomic_workspace_index_size); | |||||
| if (!has_atomic_input && !has_atomic_output && atomic_workspace_index_size.empty()) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -137,7 +137,6 @@ Status AttachStreamLabelPass::AttachFlag(const NodePtr &node, std::string &strea | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| stream_label = node->GetInDataNodes().at(0)->GetName(); | stream_label = node->GetInDataNodes().at(0)->GetName(); | ||||
| GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "Set stream label failed."); | |||||
| bool value = false; | bool value = false; | ||||
| OpDescPtr op_desc = node->GetOpDesc(); | OpDescPtr op_desc = node->GetOpDesc(); | ||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| @@ -35,9 +35,9 @@ | |||||
| #include "graph/utils/op_desc_utils.h" | #include "graph/utils/op_desc_utils.h" | ||||
| #include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| #include "utils/node_utils.h" | |||||
| namespace ge { | namespace ge { | ||||
| Status PassUtils::ConstructTensorDescWithData(const GeTensorDesc &out_desc, std::vector<int64_t> &data, | Status PassUtils::ConstructTensorDescWithData(const GeTensorDesc &out_desc, std::vector<int64_t> &data, | ||||
| std::vector<GeTensorPtr> &v_output, const bool scalar_output) { | std::vector<GeTensorPtr> &v_output, const bool scalar_output) { | ||||
| Status ret = SUCCESS; | Status ret = SUCCESS; | ||||
| @@ -246,6 +246,12 @@ NodePtr PassUtils::GetInDataNode(const ConstNodePtr &node, int index) { | |||||
| return src_node; | return src_node; | ||||
| } | } | ||||
| NodePtr PassUtils::GetInNodeCrossSubgraphByIndex(const ConstNodePtr &node, int index) { | |||||
| auto src_node = GetInDataNode(node, index); | |||||
| return NodeUtils::GetInNodeCrossSubgraph(src_node); | |||||
| } | |||||
| bool PassUtils::IsNeedTrainIteFlowCtrl(const ComputeGraphPtr &compute_graph) { | bool PassUtils::IsNeedTrainIteFlowCtrl(const ComputeGraphPtr &compute_graph) { | ||||
| if (compute_graph == nullptr) { | if (compute_graph == nullptr) { | ||||
| return false; | return false; | ||||
| @@ -30,6 +30,8 @@ class PassUtils { | |||||
| static NodePtr GetInDataNode(const ConstNodePtr &node, int index); | static NodePtr GetInDataNode(const ConstNodePtr &node, int index); | ||||
| static NodePtr GetInNodeCrossSubgraphByIndex(const ConstNodePtr &node, int index); | |||||
| static bool IsConstant(const ConstNodePtr &node); | static bool IsConstant(const ConstNodePtr &node); | ||||
| static Status SetOutNodeWeight(const OutDataAnchorPtr &out_data_anchor, const NodePtr &src_node); | static Status SetOutNodeWeight(const OutDataAnchorPtr &out_data_anchor, const NodePtr &src_node); | ||||
| @@ -279,7 +279,7 @@ Status SubexpressionMigrationPass::GraphNodeMigration(const ComputeGraphPtr &gra | |||||
| const auto &in_anchor = in_anchors.at(i); | const auto &in_anchor = in_anchors.at(i); | ||||
| const auto &base_node = in_anchor->GetOwnerNode(); | const auto &base_node = in_anchor->GetOwnerNode(); | ||||
| GELOGD("Get Data direct node: %s", base_node->GetName().c_str()); | GELOGD("Get Data direct node: %s", base_node->GetName().c_str()); | ||||
| if (!base_node->GetHostNode()) { | |||||
| if (!base_node->GetHostNode() || base_node->GetType() == SWITCH) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -94,6 +94,12 @@ Status SwitchDeadBranchElimination::DeleteSwitchNode(NodePtr &node, NodePtr &pre | |||||
| GELOGE(FAILED, "parameter is null."); | GELOGE(FAILED, "parameter is null."); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| // If two nodes aren't in same graph, get node's direct in_node instead of pred_node. | |||||
| if (node->GetOwnerComputeGraph() != pred_node->GetOwnerComputeGraph()) { | |||||
| pred_node = PassUtils::GetInDataNode(node, kPredInputIndex); | |||||
| } | |||||
| // link pred's in control nodes to switch | // link pred's in control nodes to switch | ||||
| if (GraphUtils::CopyInCtrlEdges(pred_node, node) != GRAPH_SUCCESS) { | if (GraphUtils::CopyInCtrlEdges(pred_node, node) != GRAPH_SUCCESS) { | ||||
| return FAILED; | return FAILED; | ||||
| @@ -131,7 +137,7 @@ Status SwitchDeadBranchElimination::Run(NodePtr &node) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| auto pred_node = PassUtils::GetInDataNode(node, kPredInputIndex); | |||||
| auto pred_node = PassUtils::GetInNodeCrossSubgraphByIndex(node, kPredInputIndex); | |||||
| if (pred_node == nullptr) { | if (pred_node == nullptr) { | ||||
| GELOGD("[%s] Pred input is null.", node->GetName().c_str()); | GELOGD("[%s] Pred input is null.", node->GetName().c_str()); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -143,7 +149,7 @@ Status SwitchDeadBranchElimination::Run(NodePtr &node) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| auto input_node = PassUtils::GetInDataNode(node, kDataInputIndex); | |||||
| auto input_node = PassUtils::GetInNodeCrossSubgraphByIndex(node, kDataInputIndex); | |||||
| if (input_node == nullptr) { | if (input_node == nullptr) { | ||||
| GELOGD("[%s] Data input is null.", node->GetName().c_str()); | GELOGD("[%s] Data input is null.", node->GetName().c_str()); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -448,6 +448,8 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) | |||||
| // select first stream_switch | // select first stream_switch | ||||
| NodePtr stream_switch = switch_list.front(); | NodePtr stream_switch = switch_list.front(); | ||||
| // set stream_label | |||||
| GE_CHK_STATUS_RET(SetStreamLabel(stream_switch, cast_node->GetName()), "Set stream label failed."); | |||||
| OpDescPtr switch_desc = stream_switch->GetOpDesc(); | OpDescPtr switch_desc = stream_switch->GetOpDesc(); | ||||
| GE_CHECK_NOTNULL(switch_desc); | GE_CHECK_NOTNULL(switch_desc); | ||||
| switch_desc->SetName(CheckDuplicateName(cond_group + "/" + STREAMSWITCH + (true_branch_flag ? "_t" : "_f"))); | switch_desc->SetName(CheckDuplicateName(cond_group + "/" + STREAMSWITCH + (true_branch_flag ? "_t" : "_f"))); | ||||
| @@ -3,6 +3,7 @@ set(PROTO_LIST | |||||
| ) | ) | ||||
| protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | ||||
| protobuf_generate(ge_atcstub PROTO_ATCSTUB_SRCS PROTO_ATCSTUB_HDRS ${PROTO_LIST}) | |||||
| set(SRC_LIST | set(SRC_LIST | ||||
| "engine/host_cpu_engine.cc" | "engine/host_cpu_engine.cc" | ||||
| @@ -61,7 +62,7 @@ target_link_libraries(host_cpu_engine PRIVATE | |||||
| ) | ) | ||||
| ############ atcstub/libhost_cpu_engine.so ############ | ############ atcstub/libhost_cpu_engine.so ############ | ||||
| add_library(atc_host_cpu_engine SHARED ${SRC_LIST} ${PROTO_HDRS}) | |||||
| add_library(atc_host_cpu_engine SHARED ${SRC_LIST} ${PROTO_ATCSTUB_HDRS}) | |||||
| target_compile_options(atc_host_cpu_engine PRIVATE | target_compile_options(atc_host_cpu_engine PRIVATE | ||||
| -Werror | -Werror | ||||
| @@ -84,7 +85,7 @@ target_include_directories(atc_host_cpu_engine PRIVATE | |||||
| ${METADEF_DIR}/inc/external | ${METADEF_DIR}/inc/external | ||||
| ${METADEF_DIR}/inc/external/graph | ${METADEF_DIR}/inc/external/graph | ||||
| ${CMAKE_BINARY_DIR} | ${CMAKE_BINARY_DIR} | ||||
| ${CMAKE_BINARY_DIR}/proto/ge | |||||
| ${CMAKE_BINARY_DIR}/proto/ge_atcstub | |||||
| #### yellow zone #### | #### yellow zone #### | ||||
| ${GE_CODE_DIR}/../inc | ${GE_CODE_DIR}/../inc | ||||
| #### blue zone #### | #### blue zone #### | ||||
| @@ -71,6 +71,7 @@ struct GraphExecutionContext { | |||||
| std::atomic_bool is_eos_; | std::atomic_bool is_eos_; | ||||
| long profiling_level = 0; | long profiling_level = 0; | ||||
| long iteration = 0; | long iteration = 0; | ||||
| void *global_step = nullptr; | |||||
| private: | private: | ||||
| Status status = SUCCESS; | Status status = SUCCESS; | ||||
| @@ -33,6 +33,9 @@ HybridModelExecutor::~HybridModelExecutor() { | |||||
| if (context_.rt_gen_context != nullptr) { | if (context_.rt_gen_context != nullptr) { | ||||
| (void) rtCtxDestroy(context_.rt_gen_context); | (void) rtCtxDestroy(context_.rt_gen_context); | ||||
| } | } | ||||
| if (context_.global_step != nullptr) { | |||||
| (void) rtFree(context_.global_step); | |||||
| } | |||||
| } | } | ||||
| Status HybridModelExecutor::Init() { | Status HybridModelExecutor::Init() { | ||||
| @@ -47,6 +50,8 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) { | |||||
| auto root_graph_item = model_->GetRootGraphItem(); | auto root_graph_item = model_->GetRootGraphItem(); | ||||
| GE_CHECK_NOTNULL(root_graph_item); | GE_CHECK_NOTNULL(root_graph_item); | ||||
| GE_CHK_RT_RET(rtMemcpyAsync(context_.global_step, sizeof(uint64_t), &context_.iteration, | |||||
| sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE_EX, context_.stream)); | |||||
| SubgraphExecutor executor(model_->GetRootGraphItem(), &context_); | SubgraphExecutor executor(model_->GetRootGraphItem(), &context_); | ||||
| auto ret = ExecuteGraphInternal(executor, args); | auto ret = ExecuteGraphInternal(executor, args); | ||||
| Cleanup(); | Cleanup(); | ||||
| @@ -97,6 +102,7 @@ Status HybridModelExecutor::InitExecutionContext() { | |||||
| GE_CHK_RT_RET(rtCtxGetCurrent(&context_.rt_context)); | GE_CHK_RT_RET(rtCtxGetCurrent(&context_.rt_context)); | ||||
| GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0)); | GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0)); | ||||
| GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); | GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); | ||||
| GE_CHK_RT_RET(rtMalloc(&context_.global_step, sizeof(uint64_t), RT_MEMORY_HBM)); | |||||
| context_.stream = stream_; | context_.stream = stream_; | ||||
| context_.model = model_; | context_.model = model_; | ||||
| @@ -231,12 +231,6 @@ Status NodeDoneCallback::DumpDynamicNode() { | |||||
| uint32_t model_id = model->GetModelId(); | uint32_t model_id = model->GetModelId(); | ||||
| dump_op_.SetDynamicModelInfo(dynamic_model_name, model_id); | dump_op_.SetDynamicModelInfo(dynamic_model_name, model_id); | ||||
| void *global_step = nullptr; | |||||
| TensorValue *varible_global_step = context_->GetVariable(NODE_NAME_GLOBAL_STEP); | |||||
| if (varible_global_step != nullptr) { | |||||
| global_step = const_cast<void *>(varible_global_step->GetData()); | |||||
| } | |||||
| void *loop_per_iter = nullptr; | void *loop_per_iter = nullptr; | ||||
| TensorValue *varible_loop_per_iter = context_->GetVariable(NODE_NAME_FLOWCTRL_LOOP_PER_ITER); | TensorValue *varible_loop_per_iter = context_->GetVariable(NODE_NAME_FLOWCTRL_LOOP_PER_ITER); | ||||
| if (varible_loop_per_iter != nullptr) { | if (varible_loop_per_iter != nullptr) { | ||||
| @@ -248,6 +242,7 @@ Status NodeDoneCallback::DumpDynamicNode() { | |||||
| if (varible_loop_cond != nullptr) { | if (varible_loop_cond != nullptr) { | ||||
| loop_cond = const_cast<void *>(varible_loop_cond->GetData()); | loop_cond = const_cast<void *>(varible_loop_cond->GetData()); | ||||
| } | } | ||||
| void *global_step = context_->GetExecutionContext()->global_step; | |||||
| dump_op_.SetLoopAddr(global_step, loop_per_iter, loop_cond); | dump_op_.SetLoopAddr(global_step, loop_per_iter, loop_cond); | ||||
| GE_CHK_STATUS_RET(dump_op_.LaunchDumpOp(), "Failed to launch dump op in hybird model"); | GE_CHK_STATUS_RET(dump_op_.LaunchDumpOp(), "Failed to launch dump op in hybird model"); | ||||
| @@ -255,9 +255,7 @@ Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **n | |||||
| (void) AttrUtils::SetBool(new_node->op_desc, kIsFirstNode, false); | (void) AttrUtils::SetBool(new_node->op_desc, kIsFirstNode, false); | ||||
| (void) AttrUtils::SetBool(new_node->op_desc, kIsLastNode, false); | (void) AttrUtils::SetBool(new_node->op_desc, kIsLastNode, false); | ||||
| new_node->node_id = node_index; | |||||
| new_node->op_desc->SetId(node_index); | |||||
| node_index += 1; | |||||
| new_node->node_id = static_cast<int>(new_node->op_desc->GetId()); | |||||
| NodeExecutorManager::ExecutorType executor_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node); | NodeExecutorManager::ExecutorType executor_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node); | ||||
| new_node->is_profiling_report = (executor_type == NodeExecutorManager::ExecutorType::AICORE) || | new_node->is_profiling_report = (executor_type == NodeExecutorManager::ExecutorType::AICORE) || | ||||
| (executor_type == NodeExecutorManager::ExecutorType::AICPU_TF) || | (executor_type == NodeExecutorManager::ExecutorType::AICPU_TF) || | ||||
| @@ -273,16 +271,16 @@ Status HybridModelBuilder::ParseForceInfershapeNodes(const NodePtr &node, NodeIt | |||||
| // not care result, if no this attr, stand for the op does not need force infershape | // not care result, if no this attr, stand for the op does not need force infershape | ||||
| (void)AttrUtils::GetBool(op_desc, kForceInfershape, node_item.is_need_force_infershape); | (void)AttrUtils::GetBool(op_desc, kForceInfershape, node_item.is_need_force_infershape); | ||||
| GELOGD("node [%s] is need do infershape , flag is %d", | GELOGD("node [%s] is need do infershape , flag is %d", | ||||
| op_desc->GetName().c_str(), | |||||
| node_item.is_need_force_infershape); | |||||
| op_desc->GetName().c_str(), | |||||
| node_item.is_need_force_infershape); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies) { | Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies) { | ||||
| std::set<NodePtr> dependent_input_nodes; | |||||
| std::set<NodePtr> dependent_for_shape_inference; | |||||
| std::set<NodePtr> dependent_for_execution; | |||||
| auto &ge_node = node_item.node; | auto &ge_node = node_item.node; | ||||
| bool is_hccl_op = | |||||
| NodeExecutorManager::GetInstance().ResolveExecutorType(*ge_node) == NodeExecutorManager::ExecutorType::HCCL; | |||||
| bool is_hccl_op = node_item.IsHcclOp(); | |||||
| // The input tensors become valid after computation is done for parent nodes of type DEPEND_COMPUTE. | // The input tensors become valid after computation is done for parent nodes of type DEPEND_COMPUTE. | ||||
| // Wait for these parent nodes before execution. | // Wait for these parent nodes before execution. | ||||
| @@ -297,29 +295,15 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s | |||||
| auto src_node_item = MutableNodeItem(src_node); | auto src_node_item = MutableNodeItem(src_node); | ||||
| GE_CHECK_NOTNULL(src_node_item); | GE_CHECK_NOTNULL(src_node_item); | ||||
| if (is_hccl_op) { | |||||
| GELOGD("[%s] Add input data dependent node [%s] due to engine type is HCCL", | |||||
| node_item.NodeName().c_str(), | |||||
| src_node_item->NodeName().c_str()); | |||||
| src_node_item->has_observer = true; | |||||
| node_item.dependents_for_execution.emplace_back(src_node); | |||||
| node_item.has_observer = true; | |||||
| for (auto &dst_node : ge_node->GetOutNodes()) { | |||||
| if (dst_node == nullptr) { | |||||
| continue; | |||||
| } | |||||
| NodeItem *dst_node_item = nullptr; | |||||
| GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(dst_node, &dst_node_item)); | |||||
| dst_node_item->dependents_for_execution.emplace_back(ge_node); | |||||
| } | |||||
| } else if (src_node_item->shape_inference_type == DEPEND_COMPUTE) { | |||||
| GELOGD("[%s] Add input data dependent node [%s] due to inference type = DEPEND_COMPUTE", | |||||
| node_item.NodeName().c_str(), | |||||
| src_node_item->NodeName().c_str()); | |||||
| if (src_node_item->shape_inference_type == DEPEND_COMPUTE || is_hccl_op || src_node_item->IsHcclOp()) { | |||||
| GELOGD("[%s](%s) Add input data dependent node [%s](%s), shape inference type = %d", | |||||
| ge_node->GetName().c_str(), | |||||
| ge_node->GetType().c_str(), | |||||
| src_node->GetName().c_str(), | |||||
| src_node->GetType().c_str(), | |||||
| static_cast<int>(src_node_item->shape_inference_type)); | |||||
| src_node_item->has_observer = true; | src_node_item->has_observer = true; | ||||
| node_item.dependents_for_execution.emplace_back(src_node); | |||||
| dependent_for_execution.emplace(src_node); | |||||
| } | } | ||||
| if (src_node_item->shape_inference_type == DEPEND_SHAPE_RANGE) { | if (src_node_item->shape_inference_type == DEPEND_SHAPE_RANGE) { | ||||
| @@ -327,22 +311,17 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s | |||||
| node_item.NodeName().c_str(), | node_item.NodeName().c_str(), | ||||
| src_node_item->NodeName().c_str()); | src_node_item->NodeName().c_str()); | ||||
| src_node_item->has_observer = true; | src_node_item->has_observer = true; | ||||
| dependent_input_nodes.emplace(src_node); | |||||
| dependent_for_shape_inference.emplace(src_node); | |||||
| } | } | ||||
| } | } | ||||
| // cond or branch need to be prepared before the execution of IF or CASE | // cond or branch need to be prepared before the execution of IF or CASE | ||||
| if (node_item.node_type == IF || node_item.node_type == STATELESSIF || node_item.node_type == CASE) { | if (node_item.node_type == IF || node_item.node_type == STATELESSIF || node_item.node_type == CASE) { | ||||
| const auto &in_anchor = ge_node->GetInDataAnchor(0); | |||||
| GE_CHECK_NOTNULL(in_anchor); | |||||
| const auto &peer_anchor = in_anchor->GetPeerOutAnchor(); | |||||
| GE_CHECK_NOTNULL(peer_anchor); | |||||
| auto src_node = peer_anchor->GetOwnerNode(); | |||||
| auto src_node = NodeUtils::GetInDataNodeByIndex(*ge_node, 0); // cond input | |||||
| GE_CHECK_NOTNULL(src_node); | GE_CHECK_NOTNULL(src_node); | ||||
| auto src_node_item = MutableNodeItem(src_node); | auto src_node_item = MutableNodeItem(src_node); | ||||
| GE_CHECK_NOTNULL(src_node_item); | GE_CHECK_NOTNULL(src_node_item); | ||||
| src_node_item->has_observer = true; | |||||
| node_item.dependents_for_execution.emplace_back(src_node); | |||||
| dependent_for_execution.emplace(src_node); | |||||
| GELOGD("[%s] Dependent added from %s for control op's cond/branch", | GELOGD("[%s] Dependent added from %s for control op's cond/branch", | ||||
| node_item.NodeName().c_str(), | node_item.NodeName().c_str(), | ||||
| src_node_item->NodeName().c_str()); | src_node_item->NodeName().c_str()); | ||||
| @@ -366,24 +345,32 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s | |||||
| GE_CHECK_NOTNULL(src_node); | GE_CHECK_NOTNULL(src_node); | ||||
| auto src_node_item = MutableNodeItem(src_node); | auto src_node_item = MutableNodeItem(src_node); | ||||
| src_node_item->to_const_output_id_list.emplace(peer_out_anchor->GetIdx()); | src_node_item->to_const_output_id_list.emplace(peer_out_anchor->GetIdx()); | ||||
| src_node_item->has_observer = true; | |||||
| dependent_input_nodes.emplace(src_node); | |||||
| dependent_for_shape_inference.emplace(src_node); | |||||
| GELOGD("[%s] Dependent added from output of [%s:%d]", | GELOGD("[%s] Dependent added from output of [%s:%d]", | ||||
| node_item.NodeName().c_str(), | node_item.NodeName().c_str(), | ||||
| src_node_item->NodeName().c_str(), | src_node_item->NodeName().c_str(), | ||||
| peer_out_anchor->GetIdx()); | peer_out_anchor->GetIdx()); | ||||
| } | } | ||||
| for (const auto &dep_node : dependent_input_nodes) { | |||||
| GE_CHK_STATUS_RET(ParseDependentForFusedSubgraph(node_item, dependent_for_shape_inference)); | |||||
| for (const auto &dep_node : dependent_for_shape_inference) { | |||||
| auto src_node_item = MutableNodeItem(dep_node); | |||||
| GE_CHECK_NOTNULL(src_node_item); | |||||
| src_node_item->has_observer = true; | |||||
| node_item.dependents_for_shape_inference.emplace_back(dep_node); | node_item.dependents_for_shape_inference.emplace_back(dep_node); | ||||
| } | } | ||||
| GE_CHK_STATUS_RET(ParseDependentForFusedSubgraph(node_item)); | |||||
| for (const auto &dep_node : dependent_for_execution) { | |||||
| auto src_node_item = MutableNodeItem(dep_node); | |||||
| GE_CHECK_NOTNULL(src_node_item); | |||||
| src_node_item->has_observer = true; | |||||
| node_item.dependents_for_execution.emplace_back(dep_node); | |||||
| } | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status HybridModelBuilder::ParseDependentForFusedSubgraph(NodeItem &node_item) { | |||||
| Status HybridModelBuilder::ParseDependentForFusedSubgraph(NodeItem &node_item, std::set<ge::NodePtr> &dependencies) { | |||||
| if (node_item.fused_subgraph == nullptr) { | if (node_item.fused_subgraph == nullptr) { | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -413,17 +400,12 @@ Status HybridModelBuilder::ParseDependentForFusedSubgraph(NodeItem &node_item) { | |||||
| node_item.NodeName().c_str(), | node_item.NodeName().c_str(), | ||||
| op_desc->GetName().c_str(), | op_desc->GetName().c_str(), | ||||
| src_node_item->NodeName().c_str()); | src_node_item->NodeName().c_str()); | ||||
| src_node_item->has_observer = true; | |||||
| src_node_item->to_const_output_id_list.emplace(peer_out_anchor->GetIdx()); | src_node_item->to_const_output_id_list.emplace(peer_out_anchor->GetIdx()); | ||||
| auto &depends = node_item.dependents_for_shape_inference; | |||||
| if (std::find(depends.begin(), depends.end(), src_node) == depends.end()) { | |||||
| depends.emplace_back(src_node); | |||||
| GELOGD("[%s] Dependent added from output of [%s:%d]", | |||||
| node_item.NodeName().c_str(), | |||||
| src_node_item->NodeName().c_str(), | |||||
| peer_out_anchor->GetIdx()); | |||||
| } | |||||
| dependencies.emplace(src_node); | |||||
| GELOGD("[%s] Dependent added from output of [%s:%d]", | |||||
| node_item.NodeName().c_str(), | |||||
| src_node_item->NodeName().c_str(), | |||||
| peer_out_anchor->GetIdx()); | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -770,9 +752,23 @@ Status HybridModelBuilder::LoadGraph() { | |||||
| GELOGI("After merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", | GELOGI("After merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", | ||||
| root_graph->GetDirectNodesSize(), | root_graph->GetDirectNodesSize(), | ||||
| root_graph->GetAllNodesSize()); | root_graph->GetAllNodesSize()); | ||||
| GE_DUMP(root_graph, "hybrid_merged_graph"); | |||||
| } | } | ||||
| root_graph_ = root_graph; | |||||
| // Reset node id by topological order across all subgraphs | |||||
| int64_t index = 0; | |||||
| for (const auto &node : root_graph->GetAllNodes()) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| auto parent_graph = node->GetOwnerComputeGraph(); | |||||
| // No need to update nodes in known subgraph | |||||
| if (parent_graph != nullptr && !parent_graph->GetGraphUnknownFlag()) { | |||||
| continue; | |||||
| } | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| op_desc->SetId(index++); | |||||
| } | |||||
| GE_DUMP(root_graph, "hybrid_merged_graph"); | |||||
| GE_CHK_STATUS_RET(LoadDynamicSubgraph(*root_graph, true), "Failed to load root graph."); | GE_CHK_STATUS_RET(LoadDynamicSubgraph(*root_graph, true), "Failed to load root graph."); | ||||
| GELOGD("Done loading root graph successfully."); | GELOGD("Done loading root graph successfully."); | ||||
| GE_CHK_STATUS_RET(hybrid_model_.root_graph_item_->GroupNodes(), "Failed to group nodes for root graph"); | GE_CHK_STATUS_RET(hybrid_model_.root_graph_item_->GroupNodes(), "Failed to group nodes for root graph"); | ||||
| @@ -810,6 +806,7 @@ Status HybridModelBuilder::LoadGraph() { | |||||
| } | } | ||||
| } | } | ||||
| GE_CHK_STATUS_RET(ParseDependentByParallelGroup(), "Failed to establish dependencies for hccl ops"); | |||||
| GELOGI("Done loading all subgraphs successfully."); | GELOGI("Done loading all subgraphs successfully."); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -1075,25 +1072,41 @@ Status HybridModelBuilder::InitWeights() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status HybridModelBuilder::LoadTask(NodeItem &node_item) { | |||||
| auto &node_ptr = node_item.node; | |||||
| GELOGD("[%s] Start to build kernel task", node_ptr->GetName().c_str()); | |||||
| auto load_ret = node_item.node_executor->LoadTask(hybrid_model_, | |||||
| node_ptr, | |||||
| node_item.kernel_task); | |||||
| if (load_ret != UNSUPPORTED && load_ret != SUCCESS) { | |||||
| GELOGE(load_ret, "[%s] Failed to load task", node_ptr->GetName().c_str()); | |||||
| return load_ret; | |||||
| } | |||||
| GELOGD("[%s] Done loading task successfully.", node_ptr->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status HybridModelBuilder::LoadTasks() { | Status HybridModelBuilder::LoadTasks() { | ||||
| GE_CHK_STATUS_RET(CheckAicpuOpList(), "Check Aicpu op failed."); | GE_CHK_STATUS_RET(CheckAicpuOpList(), "Check Aicpu op failed."); | ||||
| std::map<int, std::map<std::string, NodeItem *>> ordered_partitioned_calls; | |||||
| for (auto &it : hybrid_model_.node_items_) { | for (auto &it : hybrid_model_.node_items_) { | ||||
| auto &node_item = it.second; | auto &node_item = it.second; | ||||
| auto &node_ptr = node_item->node; | |||||
| if (node_item->node_type == NETOUTPUT) { | if (node_item->node_type == NETOUTPUT) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| GELOGD("[%s] Start to build kernel task", node_ptr->GetName().c_str()); | |||||
| auto load_ret = node_item->node_executor->LoadTask(hybrid_model_, | |||||
| node_ptr, | |||||
| node_item->kernel_task); | |||||
| if (load_ret != UNSUPPORTED && load_ret != SUCCESS) { | |||||
| GELOGE(load_ret, "[%s] Failed to load task", node_ptr->GetName().c_str()); | |||||
| return load_ret; | |||||
| if (node_item->node_type == PARTITIONEDCALL) { | |||||
| ordered_partitioned_calls[node_item->node_id][node_item->node_name] = node_item.get(); | |||||
| continue; | |||||
| } | } | ||||
| GE_CHK_STATUS_RET_NOLOG(LoadTask(*node_item)); | |||||
| } | |||||
| GELOGD("[%s] Done loading task successfully.", node_ptr->GetName().c_str()); | |||||
| // HCCL operators need to be loaded in the same order across different processes | |||||
| for (auto &it : ordered_partitioned_calls) { | |||||
| for (auto &it2 : it.second) { | |||||
| GE_CHK_STATUS_RET_NOLOG(LoadTask(*it2.second)); | |||||
| } | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -1626,6 +1639,7 @@ Status HybridModelBuilder::LoadKnownShapedSubgraph(ComputeGraph &graph, NodeItem | |||||
| auto temp_graph = MakeShared<ComputeGraph>("temp"); | auto temp_graph = MakeShared<ComputeGraph>("temp"); | ||||
| GE_CHECK_NOTNULL(temp_graph); | GE_CHECK_NOTNULL(temp_graph); | ||||
| auto wrapper_node = temp_graph->AddNode(wrapper_op_desc); | auto wrapper_node = temp_graph->AddNode(wrapper_op_desc); | ||||
| wrapper_op_desc->SetId(parent_node_item->node_id); | |||||
| GeModelPtr ge_model = subgraph_models_[subgraph_name]; | GeModelPtr ge_model = subgraph_models_[subgraph_name]; | ||||
| GE_CHECK_NOTNULL(ge_model); | GE_CHECK_NOTNULL(ge_model); | ||||
| hybrid_model_.known_shape_sub_models_.emplace(wrapper_node, ge_model); | hybrid_model_.known_shape_sub_models_.emplace(wrapper_node, ge_model); | ||||
| @@ -2011,5 +2025,93 @@ Status HybridModelBuilder::CheckAicpuOpList() { | |||||
| "Launch check aicpu op type failed."); | "Launch check aicpu op type failed."); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status HybridModelBuilder::CollectParallelGroups(NodeItem *node_item) { | |||||
| const auto &node = node_item->node; | |||||
| auto executor_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node); | |||||
| if (executor_type == NodeExecutorManager::ExecutorType::HCCL) { | |||||
| std::string parallel_group; | |||||
| if (AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, parallel_group)) { | |||||
| GELOGD("[%s] Got parallel group = [%s]", node_item->NodeName().c_str(), parallel_group.c_str()); | |||||
| parallel_group_to_nodes_[parallel_group].emplace(node_item); | |||||
| std::set<std::string> group{parallel_group}; | |||||
| node_to_parallel_groups_[node_item].emplace(parallel_group); | |||||
| } | |||||
| } else if (executor_type == NodeExecutorManager::ExecutorType::COMPILED_SUBGRAPH) { | |||||
| std::set<std::string> parallel_groups; | |||||
| GELOGD("[%s] To collect parallel group for known-shaped subgraph", node_item->NodeName().c_str()); | |||||
| for (const auto &subgraph_name : node->GetOpDesc()->GetSubgraphInstanceNames()) { | |||||
| GELOGD("[%s] Start to get parallel group from subgraph: %s", | |||||
| node_item->NodeName().c_str(), | |||||
| subgraph_name.c_str()); | |||||
| auto subgraph = root_graph_->GetSubgraph(subgraph_name); | |||||
| GE_CHECK_NOTNULL(subgraph); | |||||
| for (const auto &sub_node : subgraph->GetAllNodes()) { | |||||
| std::string parallel_group; | |||||
| if (AttrUtils::GetStr(sub_node->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, parallel_group)) { | |||||
| GELOGD("[%s::%s] Got parallel group = %s", | |||||
| subgraph_name.c_str(), | |||||
| sub_node->GetName().c_str(), | |||||
| parallel_group.c_str()); | |||||
| parallel_groups.emplace(parallel_group); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (!parallel_groups.empty()) { | |||||
| for (const auto ¶llel_group : parallel_groups) { | |||||
| parallel_group_to_nodes_[parallel_group].emplace(node_item); | |||||
| GELOGD("[%s] has parallel group: %s", node_item->NodeName().c_str(), parallel_group.c_str()); | |||||
| } | |||||
| node_to_parallel_groups_.emplace(node_item, std::move(parallel_groups)); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status HybridModelBuilder::ParseDependentByParallelGroup() { | |||||
| for (auto &it : hybrid_model_.node_items_) { | |||||
| GE_CHK_STATUS_RET_NOLOG(CollectParallelGroups(it.second.get())); | |||||
| } | |||||
| for (const auto &it : node_to_parallel_groups_) { | |||||
| auto node_item = it.first; | |||||
| auto dst_executor_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node_item->node); | |||||
| for (const auto ¶llel_group : it.second) { | |||||
| auto &dependent_nodes = parallel_group_to_nodes_[parallel_group]; | |||||
| NodeItem *nearest_dep_node = nullptr; | |||||
| int max_id = -1; | |||||
| for (auto &dep_node : dependent_nodes) { | |||||
| if (dep_node->node_id < node_item->node_id && dep_node->node_id > max_id) { | |||||
| nearest_dep_node = dep_node; | |||||
| max_id = dep_node->node_id; | |||||
| } | |||||
| } | |||||
| if (nearest_dep_node != nullptr) { | |||||
| GELOGD("[%s] Nearest node = [%s]", node_item->NodeName().c_str(), nearest_dep_node->NodeName().c_str()); | |||||
| auto src_engine_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*nearest_dep_node->node); | |||||
| if (src_engine_type == dst_executor_type) { | |||||
| GELOGD("No need to add dependency for nodes with same executor type"); | |||||
| continue; | |||||
| } | |||||
| auto &deps = node_item->dependents_for_execution; | |||||
| if (std::find(deps.begin(), deps.end(), nearest_dep_node->node) != deps.end()) { | |||||
| GELOGD("%s->%s Already has dependency, skip it", | |||||
| nearest_dep_node->node->GetName().c_str(), | |||||
| node_item->NodeName().c_str()); | |||||
| continue; | |||||
| } | |||||
| nearest_dep_node->has_observer = true; | |||||
| deps.emplace_back(nearest_dep_node->node); | |||||
| GELOGD("Add dependency for nodes with the same parallel group[%s], src = [%s], dst = [%s]", | |||||
| parallel_group.c_str(), | |||||
| nearest_dep_node->NodeName().c_str(), | |||||
| node_item->NodeName().c_str()); | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -57,14 +57,17 @@ class HybridModelBuilder { | |||||
| Status ValidateParams(); | Status ValidateParams(); | ||||
| Status LoadGraph(); | Status LoadGraph(); | ||||
| Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model); | Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model); | ||||
| Status LoadTask(NodeItem &node_item); | |||||
| Status LoadTasks(); | Status LoadTasks(); | ||||
| Status IdentifyVariableOutputs(NodeItem &node_item); | Status IdentifyVariableOutputs(NodeItem &node_item); | ||||
| Status IdentifySameInputs(NodeItem &node_item); | Status IdentifySameInputs(NodeItem &node_item); | ||||
| Status BuildNodeItem(const NodePtr &node, NodeItem &node_item); | Status BuildNodeItem(const NodePtr &node, NodeItem &node_item); | ||||
| Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item); | Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item); | ||||
| Status ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item); | Status ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item); | ||||
| Status CollectParallelGroups(NodeItem *node_item); | |||||
| Status ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies); | Status ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies); | ||||
| Status ParseDependentForFusedSubgraph(NodeItem &node_item); | |||||
| Status ParseDependentForFusedSubgraph(NodeItem &node_item, std::set<ge::NodePtr> &dependencies); | |||||
| Status ParseDependentByParallelGroup(); | |||||
| Status IndexTaskDefs(); | Status IndexTaskDefs(); | ||||
| Status IndexTaskDefs(const ComputeGraphPtr &sub_graph, const GeModelPtr &ge_model); | Status IndexTaskDefs(const ComputeGraphPtr &sub_graph, const GeModelPtr &ge_model); | ||||
| Status IndexSpecialNodes(); | Status IndexSpecialNodes(); | ||||
| @@ -97,12 +100,14 @@ class HybridModelBuilder { | |||||
| NodeItem *MutableNodeItem(const NodePtr &node); | NodeItem *MutableNodeItem(const NodePtr &node); | ||||
| GeRootModelPtr ge_root_model_; | GeRootModelPtr ge_root_model_; | ||||
| ComputeGraphPtr root_graph_; | |||||
| std::map<std::string, GeModelPtr> subgraph_models_; | std::map<std::string, GeModelPtr> subgraph_models_; | ||||
| std::map<std::string, NodePtr> constant_op_nodes_; | std::map<std::string, NodePtr> constant_op_nodes_; | ||||
| std::map<std::string, std::set<NodeItem *>> parallel_group_to_nodes_; | |||||
| std::map<NodeItem *, std::set<std::string>> node_to_parallel_groups_; | |||||
| HybridModel &hybrid_model_; | HybridModel &hybrid_model_; | ||||
| std::map<NodePtr, std::vector<std::pair<int, NodePtr>>> node_ref_inputs_; | std::map<NodePtr, std::vector<std::pair<int, NodePtr>>> node_ref_inputs_; | ||||
| int node_index = 0; | |||||
| RuntimeParam &runtime_param_; | RuntimeParam &runtime_param_; | ||||
| VarManager *var_manager_ = nullptr; | VarManager *var_manager_ = nullptr; | ||||
| @@ -251,6 +251,10 @@ bool NodeItem::IsControlOp() const { | |||||
| return ge::hybrid::IsControlOp(op_desc->GetType()); | return ge::hybrid::IsControlOp(op_desc->GetType()); | ||||
| } | } | ||||
| bool NodeItem::IsHcclOp() const { | |||||
| return NodeExecutorManager::GetInstance().ResolveExecutorType(*node) == NodeExecutorManager::ExecutorType::HCCL; | |||||
| } | |||||
| std::string NodeItem::DebugString() const { | std::string NodeItem::DebugString() const { | ||||
| std::stringstream ss; | std::stringstream ss; | ||||
| ss << "Node: "; | ss << "Node: "; | ||||
| @@ -67,6 +67,8 @@ struct NodeItem { | |||||
| bool IsControlOp() const; | bool IsControlOp() const; | ||||
| bool IsHcclOp() const; | |||||
| void SetToDynamic(); | void SetToDynamic(); | ||||
| std::string DebugString() const; | std::string DebugString() const; | ||||
| @@ -95,13 +95,6 @@ Status KnownNodeTask::UpdateArgs(TaskContext &context) { | |||||
| Status KnownNodeTask::Init(TaskContext &context) { | Status KnownNodeTask::Init(TaskContext &context) { | ||||
| // allocate output mem | // allocate output mem | ||||
| GE_CHK_STATUS_RET(context.AllocateOutputs(), "known node task allocate output failed."); | GE_CHK_STATUS_RET(context.AllocateOutputs(), "known node task allocate output failed."); | ||||
| // init davinicmodel | |||||
| if (!load_flag_) { | |||||
| davinci_model_->InitRuntimeParams(); | |||||
| GE_CHK_STATUS_RET(davinci_model_->InitVariableMem(), "init variable mem failed."); | |||||
| } | |||||
| // allocate mem base | // allocate mem base | ||||
| void *buffer = nullptr; | void *buffer = nullptr; | ||||
| if (davinci_model_->TotalMemSize() != 0) { | if (davinci_model_->TotalMemSize() != 0) { | ||||
| @@ -126,30 +119,34 @@ Status KnownNodeTask::Init(TaskContext &context) { | |||||
| auto dump_properties = context.GetDumpProperties(); | auto dump_properties = context.GetDumpProperties(); | ||||
| if (dump_properties.IsDumpOpen() || dump_properties.IsOpDebugOpen()) { | if (dump_properties.IsDumpOpen() || dump_properties.IsOpDebugOpen()) { | ||||
| davinci_model_->SetDumpProperties(dump_properties); | davinci_model_->SetDumpProperties(dump_properties); | ||||
| void *global_step = nullptr; | |||||
| TensorValue *varible_global_step = context.GetVariable(NODE_NAME_GLOBAL_STEP); | |||||
| if (varible_global_step != nullptr) { | |||||
| global_step = varible_global_step->MutableData(); | |||||
| } | |||||
| void *global_step = context.GetExecutionContext()->global_step; | |||||
| davinci_model_->SetKnownShapeGlobalStep(global_step); | davinci_model_->SetKnownShapeGlobalStep(global_step); | ||||
| } | } | ||||
| int32_t device_id = 0; | |||||
| rtError_t rt_ret = rtGetDevice(&device_id); | |||||
| if (rt_ret != RT_ERROR_NONE || device_id < 0) { | |||||
| GELOGE(rt_ret, "Call rtGetDevice failed, ret = 0x%X, device_id = %d.", rt_ret, device_id); | |||||
| return RT_ERROR_TO_GE_STATUS(rt_ret); | |||||
| } | |||||
| davinci_model_->SetDeviceId(device_id); | |||||
| GE_CHK_STATUS_RET(davinci_model_->Init(), "KnownNodeExecutor::InitDavinciModel failed."); | |||||
| load_flag_ = true; | load_flag_ = true; | ||||
| } else { | |||||
| GE_CHK_STATUS_RET(ModelManager::GetInstance()->DestroyAicpuKernel(davinci_model_->GetSessionId(), | |||||
| davinci_model_->Id(), davinci_model_->SubModelId()), "KnownNodeTask::Init destroy aicpu kernel failed."); | |||||
| } | } | ||||
| GE_CHK_STATUS_RET(ModelManager::GetInstance()->DestroyAicpuKernel(davinci_model_->GetSessionId(), | |||||
| davinci_model_->Id(), davinci_model_->SubModelId()), | |||||
| "KnownNodeTask::Init destroy aicpu kernel failed."); | |||||
| GELOGI("[%s] KnownNodeExecutor::Init success.", context.GetNodeName()); | GELOGI("[%s] KnownNodeExecutor::Init success.", context.GetNodeName()); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status KnownNodeTask::InitDavinciModel() { | |||||
| GELOGD("[Init][Model] start"); | |||||
| davinci_model_->InitRuntimeParams(); | |||||
| GE_CHK_STATUS_RET(davinci_model_->InitVariableMem(), "init variable mem failed"); | |||||
| int32_t device_id = 0; | |||||
| GE_CHK_RT_RET(rtGetDevice(&device_id)); | |||||
| davinci_model_->SetDeviceId(static_cast<uint32_t>(device_id)); | |||||
| GE_CHK_STATUS_RET(DoInitDavinciModel(), "[Init][Model] Failed to init davinci model."); | |||||
| GELOGD("[Init][Model] success"); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status KnownNodeTask::DoInitDavinciModel() { | |||||
| return davinci_model_->Init(); | |||||
| } | |||||
| Status KnownNodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const { | Status KnownNodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const { | ||||
| GELOGD("[%s] KnownNodeExecutor::PrepareTask in.", context.GetNodeName()); | GELOGD("[%s] KnownNodeExecutor::PrepareTask in.", context.GetNodeName()); | ||||
| RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeExecutorPrepareTask] Start"); | RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeExecutorPrepareTask] Start"); | ||||
| @@ -186,9 +183,11 @@ Status KnownNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node | |||||
| GE_CHK_STATUS_RET(davinci_model->Assign(ge_model), "KnownNodeExecutor::LoadTask davincimodel assign failed."); | GE_CHK_STATUS_RET(davinci_model->Assign(ge_model), "KnownNodeExecutor::LoadTask davincimodel assign failed."); | ||||
| task = MakeShared<KnownNodeTask>(davinci_model); | |||||
| GE_CHECK_NOTNULL(task); | |||||
| auto known_node_task = MakeShared<KnownNodeTask>(davinci_model); | |||||
| GE_CHECK_NOTNULL(known_node_task); | |||||
| GE_CHK_STATUS_RET_NOLOG(known_node_task->InitDavinciModel()); | |||||
| GELOGI("[%s] KnownNodeExecutor::LoadTask success.", node->GetName().c_str()); | GELOGI("[%s] KnownNodeExecutor::LoadTask success.", node->GetName().c_str()); | ||||
| task = std::move(known_node_task); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -31,11 +31,15 @@ class KnownNodeTask : public NodeTask { | |||||
| : davinci_model_(davinci_model) | : davinci_model_(davinci_model) | ||||
| {} | {} | ||||
| ~KnownNodeTask() {} | |||||
| ~KnownNodeTask() = default; | |||||
| Status UpdateArgs(TaskContext &context) override; | Status UpdateArgs(TaskContext &context) override; | ||||
| Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override; | Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override; | ||||
| Status Init(TaskContext &context) override; | Status Init(TaskContext &context) override; | ||||
| Status InitDavinciModel(); | |||||
| protected: | |||||
| virtual Status DoInitDavinciModel(); | |||||
| private: | private: | ||||
| std::shared_ptr<DavinciModel> davinci_model_ = nullptr; | std::shared_ptr<DavinciModel> davinci_model_ = nullptr; | ||||
| bool load_flag_ = false; | bool load_flag_ = false; | ||||
| @@ -47,8 +51,6 @@ class KnownNodeExecutor : public NodeExecutor { | |||||
| Status PrepareTask(NodeTask &task, TaskContext &context) const; | Status PrepareTask(NodeTask &task, TaskContext &context) const; | ||||
| Status ExecuteTask(NodeTask &task, TaskContext &context, const std::function<void()> &callback) const; | Status ExecuteTask(NodeTask &task, TaskContext &context, const std::function<void()> &callback) const; | ||||
| ~KnownNodeExecutor() {} | ~KnownNodeExecutor() {} | ||||
| private: | |||||
| std::shared_ptr<DavinciModel> davinci_model_ = nullptr; | |||||
| }; | }; | ||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -19,6 +19,9 @@ | |||||
| #include <mutex> | #include <mutex> | ||||
| #include <string> | #include <string> | ||||
| #include "graph/manager/graph_mem_allocator.h" | |||||
| #include "graph/manager/graph_caching_allocator.h" | |||||
| namespace ge { | namespace ge { | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY SingleOpManager::~SingleOpManager() { | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY SingleOpManager::~SingleOpManager() { | ||||
| for (auto &it : stream_resources_) { | for (auto &it : stream_resources_) { | ||||
| @@ -67,6 +70,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status SingleOpManager::Release | |||||
| delete it->second; | delete it->second; | ||||
| it->second = nullptr; | it->second = nullptr; | ||||
| (void)stream_resources_.erase(it); | (void)stream_resources_.erase(it); | ||||
| MemManager::Instance().CachingInstance(RT_MEMORY_HBM).TryFreeBlocks(); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -44,19 +44,46 @@ namespace ge { | |||||
| namespace { | namespace { | ||||
| const size_t kDataOutputNum = 1; | const size_t kDataOutputNum = 1; | ||||
| bool NeedHybridModel(GeModelPtr &ge_model) { | |||||
| Status IfInferDepend(GeModelPtr &ge_model, bool &flag) { | |||||
| auto comp_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph()); | |||||
| GE_CHECK_NOTNULL(comp_graph); | |||||
| for (const auto &node : comp_graph->GetAllNodes()) { | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| const auto &depends = op_desc->GetOpInferDepends(); | |||||
| if (!depends.empty()) { | |||||
| flag = true; | |||||
| return SUCCESS; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status NeedHybridModel(GeModelPtr &ge_model, bool &flag) { | |||||
| bool infer_depend_flag = false; | |||||
| GE_CHK_STATUS_RET(IfInferDepend(ge_model, infer_depend_flag), "[Check][InferDepend] failed."); | |||||
| auto tasks = ge_model->GetModelTaskDefPtr()->task(); | auto tasks = ge_model->GetModelTaskDefPtr()->task(); | ||||
| int32_t kernel_task_num = 0; | int32_t kernel_task_num = 0; | ||||
| for (int i = 0; i < tasks.size(); ++i) { | for (int i = 0; i < tasks.size(); ++i) { | ||||
| auto task_type = static_cast<rtModelTaskType_t>(tasks[i].type()); | auto task_type = static_cast<rtModelTaskType_t>(tasks[i].type()); | ||||
| if (task_type == RT_MODEL_TASK_KERNEL || task_type == RT_MODEL_TASK_ALL_KERNEL) { | if (task_type == RT_MODEL_TASK_KERNEL || task_type == RT_MODEL_TASK_ALL_KERNEL) { | ||||
| kernel_task_num++; | |||||
| if (kernel_task_num > 1) { | |||||
| return true; | |||||
| const auto &context = task_type == RT_MODEL_TASK_KERNEL ? tasks[i].kernel().context() : | |||||
| tasks[i].kernel_with_handle().context(); | |||||
| auto kernel_type = static_cast<ccKernelType>(context.kernel_type()); | |||||
| if (kernel_type == ccKernelType::TE) { | |||||
| if (infer_depend_flag) { | |||||
| flag = true; | |||||
| return SUCCESS; | |||||
| } | |||||
| kernel_task_num++; | |||||
| if (kernel_task_num > 1) { | |||||
| flag = true; | |||||
| return SUCCESS; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| return false; | |||||
| return SUCCESS; | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| @@ -503,7 +530,9 @@ Status SingleOpModel::BuildDynamicOp(StreamResource &resource, DynamicSingleOp & | |||||
| auto ge_model = model_helper_.GetGeModel(); | auto ge_model = model_helper_.GetGeModel(); | ||||
| GE_CHECK_NOTNULL(ge_model); | GE_CHECK_NOTNULL(ge_model); | ||||
| if (NeedHybridModel(ge_model)) { | |||||
| bool need_hybrid_model = false; | |||||
| GE_CHK_STATUS_RET(NeedHybridModel(ge_model, need_hybrid_model), "[Check][NeedHybridModel] failed."); | |||||
| if (need_hybrid_model) { | |||||
| GELOGD("Build single op HybridModel."); | GELOGD("Build single op HybridModel."); | ||||
| GE_CHK_STATUS_RET_NOLOG(hybrid::NodeExecutorManager::GetInstance().EnsureInitialized()); | GE_CHK_STATUS_RET_NOLOG(hybrid::NodeExecutorManager::GetInstance().EnsureInitialized()); | ||||
| auto root_model = model_helper_.GetGeRootModel(); | auto root_model = model_helper_.GetGeRootModel(); | ||||
| @@ -1 +1 @@ | |||||
| Subproject commit e68940202b874ccec77d621f59b34fc4404bede2 | |||||
| Subproject commit 0c4602a4615a9368b06633a5087e2114518f29ca | |||||
| @@ -1 +1 @@ | |||||
| Subproject commit b203d47837421b2c149f353fc0808f6a29fa584e | |||||
| Subproject commit d851e1d467768b6cefd8f5f44745be1c5312121a | |||||
| @@ -435,3 +435,7 @@ rtError_t rtGetTaskIdAndStreamID(uint32_t *taskId, uint32_t *streamId) | |||||
| rtError_t rtDebugRegisterForStream(rtStream_t stream, uint32_t flag, const void *addr, uint32_t *streamId, uint32_t *taskId) { | rtError_t rtDebugRegisterForStream(rtStream_t stream, uint32_t flag, const void *addr, uint32_t *streamId, uint32_t *taskId) { | ||||
| return RT_ERROR_NONE; | return RT_ERROR_NONE; | ||||
| } | } | ||||
| rtError_t rtDebugUnRegisterForStream(rtStream_t stream) { | |||||
| return RT_ERROR_NONE; | |||||
| } | |||||
| @@ -667,6 +667,7 @@ set(PASS_TEST_FILES | |||||
| "graph/passes/merge_pass_unittest.cc" | "graph/passes/merge_pass_unittest.cc" | ||||
| #"graph/passes/switch_pass_unittest.cc" | #"graph/passes/switch_pass_unittest.cc" | ||||
| "graph/passes/switch_logic_remove_pass_unittest.cc" | "graph/passes/switch_logic_remove_pass_unittest.cc" | ||||
| "graph/passes/switch_dead_branch_elimination_unittest.cc" | |||||
| "graph/passes/assert_pass_unittest.cc" | "graph/passes/assert_pass_unittest.cc" | ||||
| "graph/passes/dropout_pass_unittest.cc" | "graph/passes/dropout_pass_unittest.cc" | ||||
| "graph/passes/unused_const_pass_unittest.cc" | "graph/passes/unused_const_pass_unittest.cc" | ||||
| @@ -731,6 +732,7 @@ set(KERNEL_TEST_FILES | |||||
| "graph/passes/folding_kernel/gather_v2_kernel_unittest.cc" | "graph/passes/folding_kernel/gather_v2_kernel_unittest.cc" | ||||
| "graph/passes/folding_kernel/slice_kernel_unittest.cc" | "graph/passes/folding_kernel/slice_kernel_unittest.cc" | ||||
| "graph/passes/folding_kernel/dynamic_stitch_kernel_unittest.cc" | "graph/passes/folding_kernel/dynamic_stitch_kernel_unittest.cc" | ||||
| "graph/passes/atomic_addr_clean_pass_unittest.cc" | |||||
| ) | ) | ||||
| set(MULTI_PARTS_TEST_FILES | set(MULTI_PARTS_TEST_FILES | ||||
| @@ -760,6 +762,7 @@ set(MULTI_PARTS_TEST_FILES | |||||
| "graph/variable_accelerate_ctrl_unittest.cc" | "graph/variable_accelerate_ctrl_unittest.cc" | ||||
| "graph/build/logical_stream_allocator_unittest.cc" | "graph/build/logical_stream_allocator_unittest.cc" | ||||
| "graph/build/mem_assigner_unittest.cc" | "graph/build/mem_assigner_unittest.cc" | ||||
| "graph/build/task_generator_unittest.cc" | |||||
| "graph/preprocess/graph_preprocess_unittest.cc" | "graph/preprocess/graph_preprocess_unittest.cc" | ||||
| "graph/manager/hcom_util_unittest.cc" | "graph/manager/hcom_util_unittest.cc" | ||||
| "graph/manager/graph_caching_allocator_unittest.cc" | "graph/manager/graph_caching_allocator_unittest.cc" | ||||
| @@ -0,0 +1,68 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <gtest/gtest.h> | |||||
| #include <memory> | |||||
| #include "graph/anchor.h" | |||||
| #include "graph/attr_value.h" | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "graph/utils/node_utils.h" | |||||
| #include "graph/utils/op_desc_utils.h" | |||||
| #include "graph/utils/tensor_utils.h" | |||||
| #include "omg/omg_inner_types.h" | |||||
| #include "../passes/graph_builder_utils.h" | |||||
| #define protected public | |||||
| #define private public | |||||
| #include "graph/build/task_generator.h" | |||||
| #undef protected | |||||
| #undef private | |||||
| using namespace std; | |||||
| using namespace testing; | |||||
| using namespace ge; | |||||
| class UtestTaskGeneratorTest : public testing::Test { | |||||
| public: | |||||
| ge::ComputeGraphPtr BuildGraphFpProfiling() { | |||||
| ge::ut::GraphBuilder builder("graph"); | |||||
| auto data = builder.AddNode("data", "phony", 1, 1); | |||||
| auto addn1 = builder.AddNode("addn1", "AddN", 1, 1); | |||||
| auto netoutput = builder.AddNode("netoutput", "NetOutput", 2, 0); | |||||
| auto op_desc = data->GetOpDesc(); | |||||
| (void)AttrUtils::SetStr(op_desc, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, "IteratorV2"); | |||||
| op_desc->SetOpKernelLibName("GE"); | |||||
| builder.AddDataEdge(data, 0, addn1, 0); | |||||
| builder.AddDataEdge(addn1, 0, netoutput, 0); | |||||
| return builder.GetGraph(); | |||||
| } | |||||
| protected: | |||||
| void SetUp() {} | |||||
| void TearDown() {} | |||||
| }; | |||||
| TEST_F(UtestTaskGeneratorTest, AutoFindFpOpIndex) { | |||||
| auto graph = BuildGraphFpProfiling(); | |||||
| TaskGenerator task_generator(nullptr, 0); | |||||
| ProfilingPoint profiling_point; | |||||
| profiling_point.fp_index = -1; | |||||
| EXPECT_EQ(task_generator.AutoFindFpOpIndex(graph, profiling_point), SUCCESS); | |||||
| // addn1 is fp | |||||
| EXPECT_EQ(profiling_point.fp_index, 2); | |||||
| } | |||||
| @@ -0,0 +1,65 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <gtest/gtest.h> | |||||
| #include "graph/passes/atomic_addr_clean_pass.h" | |||||
| #include "common/op/ge_op_utils.h" | |||||
| #include "common/types.h" | |||||
| #include "graph/anchor.h" | |||||
| #include "graph/attr_value.h" | |||||
| #include "graph/compute_graph.h" | |||||
| #include "graph/op_desc.h" | |||||
| #include "graph/utils/attr_utils.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "graph/utils/op_desc_utils.h" | |||||
| #include "graph/utils/tensor_utils.h" | |||||
| #include "inc/pass_manager.h" | |||||
| using namespace testing; | |||||
| namespace ge { | |||||
| class UtestGraphPassesAtomicAddrCleanPass : public Test { | |||||
| public: | |||||
| UtestGraphPassesAtomicAddrCleanPass() { | |||||
| graph_ = std::make_shared<ComputeGraph>("test"); | |||||
| } | |||||
| NodePtr NewNode(const string &name, const string &type, int input_cnt, int output_cnt) { | |||||
| OpDescPtr op_desc = std::make_shared<OpDesc>(name, type); | |||||
| for (int i = 0; i < input_cnt; ++i) { | |||||
| op_desc->AddInputDesc(GeTensorDesc()); | |||||
| } | |||||
| for (int i = 0; i < output_cnt; ++i) { | |||||
| op_desc->AddOutputDesc(GeTensorDesc()); | |||||
| } | |||||
| NodePtr node = graph_->AddNode(op_desc); | |||||
| return node; | |||||
| } | |||||
| ComputeGraphPtr graph_; | |||||
| }; | |||||
| // node1 -> node2 -> node3 | |||||
| TEST_F(UtestGraphPassesAtomicAddrCleanPass, pass_run_success) { | |||||
| auto node1 = NewNode("node1", DATA, 0, 1); | |||||
| auto node2 = NewNode("node2", RELU, 1, 1); | |||||
| auto node3 = NewNode("node3", NETOUTPUT, 1, 0); | |||||
| GraphUtils::AddEdge(node1->GetOutDataAnchor(0), node2->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(node2->GetOutDataAnchor(0), node3->GetInDataAnchor(0)); | |||||
| AtomicAddrCleanPass atomi_addr_clean_pass; | |||||
| Status ret = atomi_addr_clean_pass.Run(graph_); | |||||
| EXPECT_EQ(ret, SUCCESS); | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,163 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <cstdint> | |||||
| #include <string> | |||||
| #include <gtest/gtest.h> | |||||
| #include "common/ge_inner_error_codes.h" | |||||
| #include "graph/passes/switch_dead_branch_elimination.h" | |||||
| #include "graph_builder_utils.h" | |||||
| namespace ge { | |||||
| class UtestSwitchDeadBranchElimination : public testing::Test { | |||||
| protected: | |||||
| void SetUp() {} | |||||
| void TearDown() {} | |||||
| }; | |||||
| namespace { | |||||
| /* | |||||
| * data1 const1 | |||||
| * \ / | |||||
| * case1 | |||||
| * | | |||||
| * relu1 | |||||
| * | | |||||
| * netoutput | |||||
| */ | |||||
| ut::GraphBuilder ParentGraphBuilder() { | |||||
| ut::GraphBuilder builder = ut::GraphBuilder("g1"); | |||||
| auto data1 = builder.AddNode("data1", "Data", 0, 1); | |||||
| auto const1 = builder.AddNode("const1", "Const", 0, 1); | |||||
| auto case1 = builder.AddNode("case1", CASE, 2, 1); | |||||
| auto relu1 = builder.AddNode("relu1", "Relu", 1, 1); | |||||
| auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0); | |||||
| int32_t weight[1] = {1}; | |||||
| GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_INT32); | |||||
| GeTensorPtr tensor = std::make_shared<GeTensor>(weight_desc, (uint8_t *)weight, sizeof(weight)); | |||||
| OpDescUtils::SetWeights(const1, {tensor}); | |||||
| builder.AddDataEdge(data1, 0, case1, 0); | |||||
| builder.AddDataEdge(const1, 0, case1, 1); | |||||
| builder.AddDataEdge(case1, 0, relu1, 0); | |||||
| builder.AddDataEdge(relu1, 0, netoutput, 0); | |||||
| return builder; | |||||
| } | |||||
| /* | |||||
| * data1 data2 | |||||
| * \ / | |||||
| * switch | |||||
| * / \ | |||||
| * relu1 relu2 | |||||
| * \ / | |||||
| * merge | |||||
| * | | |||||
| * netoutput | |||||
| */ | |||||
| ut::GraphBuilder SwitchSubgraphBuilder(string graph_name, uint32_t num) { | |||||
| ut::GraphBuilder builder = ut::GraphBuilder(graph_name); | |||||
| string data1_name = "data1_" + std::to_string(num); | |||||
| auto data1 = builder.AddNode(data1_name, "Data", 0, 1); | |||||
| auto data1_desc = data1->GetOpDesc(); | |||||
| EXPECT_NE(data1_desc, nullptr); | |||||
| AttrUtils::SetInt(data1_desc, "_parent_node_index", 0); | |||||
| string data2_name = "data2_" + std::to_string(num); | |||||
| auto data2 = builder.AddNode(data2_name, "Data", 0, 1); | |||||
| auto data2_desc = data2->GetOpDesc(); | |||||
| EXPECT_NE(data2_desc, nullptr); | |||||
| AttrUtils::SetInt(data2_desc, "_parent_node_index", 1); | |||||
| string switch_name = "switch_" + std::to_string(num); | |||||
| auto switch1 = builder.AddNode(switch_name, "Switch", 2, 2); | |||||
| string relu1_name = "relu1_" + std::to_string(num); | |||||
| auto relu1 = builder.AddNode(relu1_name, "Relu", 1, 1); | |||||
| string relu2_name = "relu2_" + std::to_string(num); | |||||
| auto relu2 = builder.AddNode(relu2_name, "Relu", 1, 1); | |||||
| string merge_name = "merge_" + std::to_string(num); | |||||
| auto merge = builder.AddNode(merge_name, "Merge", 2, 1); | |||||
| string output_name = "output_" + std::to_string(num); | |||||
| auto netoutput = builder.AddNode(output_name, NETOUTPUT, 1, 0); | |||||
| builder.AddDataEdge(data1, 0, switch1, 0); | |||||
| builder.AddDataEdge(data2, 0, switch1, 1); | |||||
| builder.AddDataEdge(switch1, 0, relu1, 0); | |||||
| builder.AddDataEdge(switch1, 1, relu2, 0); | |||||
| builder.AddDataEdge(relu1, 0, merge, 0); | |||||
| builder.AddDataEdge(relu2, 0, merge, 1); | |||||
| builder.AddDataEdge(merge, 0, netoutput, 0); | |||||
| return builder; | |||||
| } | |||||
| void AddCaseSubgraph(ComputeGraphPtr &parent_graph, uint32_t branch_num) { | |||||
| auto case_node = parent_graph->FindNode("case1"); | |||||
| EXPECT_NE(case_node, nullptr); | |||||
| for (uint32_t i = 0; i < branch_num; ++i) { | |||||
| string name = "Branch_Graph_" + std::to_string(i); | |||||
| auto builder_subgraph = SwitchSubgraphBuilder(name, i); | |||||
| auto switch_subgraph = builder_subgraph.GetGraph(); | |||||
| case_node->GetOpDesc()->AddSubgraphName(switch_subgraph->GetName()); | |||||
| case_node->GetOpDesc()->SetSubgraphInstanceName(i, switch_subgraph->GetName()); | |||||
| switch_subgraph->SetParentNode(case_node); | |||||
| switch_subgraph->SetParentGraph(parent_graph); | |||||
| EXPECT_EQ(parent_graph->AddSubgraph(switch_subgraph->GetName(), switch_subgraph), GRAPH_SUCCESS); | |||||
| } | |||||
| } | |||||
| } // namespace | |||||
| TEST_F(UtestSwitchDeadBranchElimination, switch_dead_branch_elimination_across_case_success) { | |||||
| auto builder = ParentGraphBuilder(); | |||||
| auto parent_graph = builder.GetGraph(); | |||||
| AddCaseSubgraph(parent_graph, 2); | |||||
| auto subgraphs = parent_graph->GetAllSubgraphs(); | |||||
| EXPECT_EQ(subgraphs.size(), 2); | |||||
| SwitchDeadBranchElimination switch_pass; | |||||
| for (auto &subgraph : subgraphs) { | |||||
| auto switch_node = subgraph->FindFirstNodeMatchType("Switch"); | |||||
| if (switch_node != nullptr) { | |||||
| EXPECT_EQ(switch_pass.Run(switch_node), SUCCESS); | |||||
| } | |||||
| } | |||||
| auto all_nodes = parent_graph->GetAllNodes(); | |||||
| EXPECT_EQ(all_nodes.size(), 17); | |||||
| for (auto &subgraph : subgraphs) { | |||||
| EXPECT_EQ(subgraph->GetDirectNode().size(), 6); | |||||
| EXPECT_EQ(subgraph->FindFirstNodeMatchType("Switch"), nullptr); | |||||
| auto merge_node = subgraph->FindFirstNodeMatchType("Merge"); | |||||
| EXPECT_NE(merge_node, nullptr); | |||||
| auto merge_innode = merge_node->GetInDataNodes(); | |||||
| EXPECT_EQ(merge_innode.size(), 1); | |||||
| } | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -30,6 +30,7 @@ | |||||
| #include "framework/common/debug/log.h" | #include "framework/common/debug/log.h" | ||||
| #include "graph/ge_context.h" | #include "graph/ge_context.h" | ||||
| #include "hybrid/executor/hybrid_execution_context.h" | #include "hybrid/executor/hybrid_execution_context.h" | ||||
| #include "hybrid/executor/hybrid_model_executor.h" | |||||
| #include "hybrid/node_executor/aicore/aicore_task_builder.h" | #include "hybrid/node_executor/aicore/aicore_task_builder.h" | ||||
| #include "graph/load/model_manager/tbe_handle_store.h" | #include "graph/load/model_manager/tbe_handle_store.h" | ||||
| #include "graph/manager/graph_mem_allocator.h" | #include "graph/manager/graph_mem_allocator.h" | ||||
| @@ -242,4 +243,16 @@ TEST_F(UtestGeHybrid, init_weight_success) { | |||||
| ge_sub_model->SetWeight(weight_buffer); | ge_sub_model->SetWeight(weight_buffer); | ||||
| ret = hybrid_model_builder.InitWeights(); | ret = hybrid_model_builder.InitWeights(); | ||||
| ASSERT_EQ(ret,PARAM_INVALID); | ASSERT_EQ(ret,PARAM_INVALID); | ||||
| } | |||||
| } | |||||
| TEST_F(UtestGeHybrid, hybrid_model_executor) { | |||||
| ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("abc"); | |||||
| GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph); | |||||
| HybridModel model(root_model); | |||||
| HybridModel *model_ptr = &model; | |||||
| uint32_t device_id = 0; | |||||
| rtStream_t stream; | |||||
| HybridModelExecutor executor(model_ptr, device_id, stream); | |||||
| executor.Init(); | |||||
| } | |||||