From 87224a512a85be3b434ce9b20d0cca1dbdaebb8a Mon Sep 17 00:00:00 2001 From: chenyemeng Date: Tue, 22 Dec 2020 16:14:56 +0800 Subject: [PATCH] GeTensor aligned addr & zero copy support --- ge/CMakeLists.txt | 4 + ge/executor/CMakeLists.txt | 1 + ge/executor/module.mk | 1 + ge/ge_inference.mk | 2 + ge/ge_local_engine/engine/host_cpu_engine.cc | 28 ++++ ge/ge_runner.mk | 2 + ge/graph/manager/graph_manager.cc | 14 ++ ge/graph/manager/graph_mem_allocator.cc | 18 ++- ge/graph/manager/graph_mem_allocator.h | 10 +- ge/graph/manager/host_mem_allocator.cc | 70 ++++++++++ ge/graph/manager/host_mem_allocator.h | 58 +++++++++ ge/graph/manager/host_mem_manager.cc | 18 ++- ge/graph/manager/host_mem_manager.h | 4 + ge/graph/passes/assign_pass.cc | 122 +++++++++++++++++- ge/graph/passes/assign_pass.h | 15 +++ ge/graph/passes/constant_fuse_same_pass.cc | 16 ++- ge/graph/passes/constant_fuse_same_pass.h | 17 ++- ge/graph/passes/inplace_support_check_pass.cc | 86 ++++++++++++ ge/graph/passes/inplace_support_check_pass.h | 28 ++++ .../passes/switch_to_stream_switch_pass.cc | 2 +- ge/graph/preprocess/graph_preprocess.cc | 6 +- ge/hybrid/common/npu_memory_allocator.cc | 11 ++ ge/hybrid/model/hybrid_model_builder.cc | 34 ++++- .../host_cpu/host_cpu_node_executor.cc | 22 +++- 24 files changed, 569 insertions(+), 20 deletions(-) create mode 100644 ge/graph/manager/host_mem_allocator.cc create mode 100644 ge/graph/manager/host_mem_allocator.h create mode 100644 ge/graph/passes/inplace_support_check_pass.cc create mode 100644 ge/graph/passes/inplace_support_check_pass.h diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index 3da80492..9391f223 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -125,6 +125,7 @@ set(TRAIN_SRC_LIST "graph/manager/graph_var_manager.cc" "graph/manager/host_mem_manager.cc" "graph/manager/rdma_pool_allocator.cc" + $<$:graph/manager/host_mem_allocator.cc> "graph/manager/memory_api.cc" "graph/manager/model_manager/event_manager.cc" "graph/manager/trans_var_data_utils.cc" @@ -166,6 +167,7 @@ set(TRAIN_SRC_LIST "graph/passes/hccl_group_pass.cc" "graph/passes/enter_pass.cc" "graph/passes/assign_pass.cc" + "graph/passes/inplace_support_check_pass.cc" "graph/passes/flow_ctrl_pass.cc" "graph/passes/global_step_insert_pass.cc" "host_kernels/transpose_kernel.cc" @@ -401,6 +403,7 @@ set(INFER_SRC_LIST "graph/manager/graph_var_manager.cc" "graph/manager/host_mem_manager.cc" "graph/manager/rdma_pool_allocator.cc" + $<$:graph/manager/host_mem_allocator.cc> "graph/manager/graph_mem_allocator.cc" "graph/manager/graph_caching_allocator.cc" "model/ge_model.cc" @@ -522,6 +525,7 @@ set(INFER_SRC_LIST "graph/passes/for_pass.cc" "graph/passes/enter_pass.cc" "graph/passes/assign_pass.cc" + "graph/passes/inplace_support_check_pass.cc" "graph/passes/addn_pass.cc" "graph/passes/common_subexpression_elimination_pass.cc" "graph/passes/remove_same_const_pass.cc" diff --git a/ge/executor/CMakeLists.txt b/ge/executor/CMakeLists.txt index d59afd03..8e6236f9 100644 --- a/ge/executor/CMakeLists.txt +++ b/ge/executor/CMakeLists.txt @@ -28,6 +28,7 @@ set(SRC_LIST "../graph/manager/trans_var_data_utils.cc" "../graph/manager/util/debug.cc" "../graph/manager/rdma_pool_allocator.cc" + $<$:../graph/manager/host_mem_allocator.cc> "../hybrid/node_executor/aicpu/aicpu_ext_info.cc" "../model/ge_model.cc" "../model/ge_root_model.cc" diff --git a/ge/executor/module.mk b/ge/executor/module.mk index 34c2a37e..87abdade 100644 --- a/ge/executor/module.mk +++ b/ge/executor/module.mk @@ -15,6 +15,7 @@ local_ge_executor_src_files := \ ../graph/manager/graph_manager_utils.cc \ ../graph/manager/graph_var_manager.cc \ ../graph/manager/rdma_pool_allocator.cc \ + ../graph/manager/host_mem_allocator.cc \ ../graph/manager/graph_mem_allocator.cc \ ../graph/manager/graph_caching_allocator.cc \ ../graph/manager/trans_var_data_utils.cc \ diff --git a/ge/ge_inference.mk b/ge/ge_inference.mk index e20456d5..b4f6f1e1 100755 --- a/ge/ge_inference.mk +++ b/ge/ge_inference.mk @@ -64,6 +64,7 @@ GRAPH_MANAGER_LOCAL_SRC_FILES := \ graph/manager/graph_var_manager.cc \ graph/manager/host_mem_manager.cc \ graph/manager/rdma_pool_allocator.cc \ + graph/manager/host_mem_allocator.cc \ graph/manager/graph_mem_allocator.cc \ graph/manager/graph_caching_allocator.cc \ @@ -196,6 +197,7 @@ OMG_HOST_SRC_FILES := \ graph/passes/for_pass.cc \ graph/passes/enter_pass.cc \ graph/passes/assign_pass.cc \ + graph/passes/inplace_support_check_pass.cc \ graph/passes/addn_pass.cc \ graph/passes/common_subexpression_elimination_pass.cc \ graph/passes/transop_symmetry_elimination_pass.cc \ diff --git a/ge/ge_local_engine/engine/host_cpu_engine.cc b/ge/ge_local_engine/engine/host_cpu_engine.cc index e17f73de..5bcff7ff 100755 --- a/ge/ge_local_engine/engine/host_cpu_engine.cc +++ b/ge/ge_local_engine/engine/host_cpu_engine.cc @@ -26,6 +26,33 @@ #include "common/math/math_util.h" namespace { +#if (ENABLE_OPEN_SRC != True) +#define CREATE_OUTPUT_CASE(DTYPE, TYPE) \ + case (DTYPE): { \ + GeTensorPtr ge_tensor = nullptr; \ + if (need_create_flag) { \ + uint64_t size = data_num * sizeof(TYPE); \ + ge_tensor = MakeShared(out_desc, size); \ + GE_CHECK_NOTNULL(ge_tensor); \ + GELOGD("node:%s allocate output %zu success, size=%lld", op_desc->GetName().c_str(), i, size); \ + ge_tensor->MutableTensorDesc().SetDataType(out_desc.GetDataType()); \ + ge_tensor->MutableTensorDesc().SetShape(out_desc.GetShape()); \ + outputs.emplace_back(ge_tensor); \ + } else { \ + ge_tensor = outputs[i]; \ + GE_CHECK_NOTNULL(ge_tensor); \ + GELOGD("node:%s existed output %zu", op_desc->GetName().c_str(), i); \ + } \ + auto tensor = TensorAdapter::AsTensor(*ge_tensor); \ + auto tensor_name = op_desc->GetOutputNameByIndex(i); \ + GE_RETURN_WITH_LOG_IF_TRUE(tensor_name.empty(), "Failed to get output name. node = %s, index = %zu", \ + op_desc->GetName().c_str(), i); \ + GELOGD("Successfully inserted output tensor. node = %s, index = %zu, output name = %s, addr = %p, size = %zu", \ + op_desc->GetName().c_str(), i, tensor_name.c_str(), tensor.GetData(), tensor.GetSize()); \ + named_outputs.emplace(tensor_name, tensor); \ + break; \ + } +#else #define CREATE_OUTPUT_CASE(DTYPE, TYPE) \ case (DTYPE): { \ GeTensorPtr ge_tensor = nullptr; \ @@ -61,6 +88,7 @@ namespace { named_outputs.emplace(tensor_name, tensor); \ break; \ } +#endif } namespace ge { diff --git a/ge/ge_runner.mk b/ge/ge_runner.mk index 9706dadb..41956493 100644 --- a/ge/ge_runner.mk +++ b/ge/ge_runner.mk @@ -94,6 +94,7 @@ LIBGE_LOCAL_SRC_FILES := \ graph/manager/graph_var_manager.cc \ graph/manager/host_mem_manager.cc \ graph/manager/rdma_pool_allocator.cc \ + graph/manager/host_mem_allocator.cc \ graph/manager/memory_api.cc \ graph/manager/model_manager/event_manager.cc \ graph/manager/trans_var_data_utils.cc \ @@ -135,6 +136,7 @@ LIBGE_LOCAL_SRC_FILES := \ graph/passes/hccl_group_pass.cc \ graph/passes/enter_pass.cc \ graph/passes/assign_pass.cc \ + graph/passes/inplace_support_check_pass.cc \ graph/passes/flow_ctrl_pass.cc \ graph/passes/global_step_insert_pass.cc \ host_kernels/transpose_kernel.cc \ diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index 46799ba3..a9ee3570 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -38,6 +38,10 @@ #include "graph/partition/stage_partition.h" #include "graph/passes/addn_pass.h" #include "graph/passes/bitcast_pass.h" +#if (ENABLE_OPEN_SRC != True) +#include "graph/passes/assign_pass.h" +#include "graph/passes/inplace_support_check_pass.h" +#endif #include "graph/passes/atomic_addr_clean_pass.h" #include "graph/passes/attach_stream_label_pass.h" #include "graph/passes/cast_remove_pass.h" @@ -2237,10 +2241,20 @@ Status GraphManager::OptimizeStage2(ge::ComputeGraphPtr &compute_graph) { ReshapeRemovePass reshape_remove_pass; CondRemovePass condition_remove_pass; BitcastPass bitcast_pass; +#if (ENABLE_OPEN_SRC != True) + AssignPass assign_pass; + InplaceSupportCheckPass inplace_support_check_pass; +#endif names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass); names_to_passes.emplace_back("ReshapeRemovePass", &reshape_remove_pass); names_to_passes.emplace_back("CondRemovePass", &condition_remove_pass); names_to_passes.emplace_back("BitcastPass", &bitcast_pass); +#if (ENABLE_OPEN_SRC != True) + if (GetContext().GetHostExecFlag()) { + names_to_passes.emplace_back("AssignPass", &assign_pass); + names_to_passes.emplace_back("InplaceSupportCheckPass", &inplace_support_check_pass); + } +#endif GE_TIMESTAMP_START(names_to_passes); ret = GEPass(compute_graph).Run(names_to_passes); GE_TIMESTAMP_END(names_to_passes, "OptimizeStage2::MergedGraphNameToPasses"); diff --git a/ge/graph/manager/graph_mem_allocator.cc b/ge/graph/manager/graph_mem_allocator.cc index 7ee7df20..92782414 100755 --- a/ge/graph/manager/graph_mem_allocator.cc +++ b/ge/graph/manager/graph_mem_allocator.cc @@ -19,7 +19,9 @@ #include #include "graph/manager/graph_caching_allocator.h" #include "graph/manager/rdma_pool_allocator.h" - +#if (ENABLE_OPEN_SRC != True) +#include "graph/manager/host_mem_allocator.h" +#endif namespace ge { void MemoryAllocator::Initialize(uint32_t device_id) { GELOGI("MemoryAllocator::Initialize"); @@ -190,6 +192,12 @@ Status MemManager::Initialize(const std::vector &memory_type) { GELOGE(ge::INTERNAL_ERROR, "Create RdmaAllocator failed."); return ge::INTERNAL_ERROR; } +#if (ENABLE_OPEN_SRC != True) + if (InitAllocator(memory_type, host_allocator_map_) != SUCCESS) { + GELOGE(ge::INTERNAL_ERROR, "Create HostMemAllocator failed."); + return ge::INTERNAL_ERROR; + } +#endif return SUCCESS; } @@ -211,6 +219,9 @@ void MemManager::Finalize() noexcept { // caching and rdma allocator use memory allocator, so finalize them first FinalizeAllocatorMap(caching_allocator_map_); FinalizeAllocatorMap(rdma_allocator_map_); +#if (ENABLE_OPEN_SRC != True) + FinalizeAllocatorMap(host_allocator_map_); +#endif FinalizeAllocatorMap(memory_allocator_map_); } @@ -239,4 +250,9 @@ CachingAllocator &MemManager::CachingInstance(rtMemType_t memory_type) { RdmaPoolAllocator &MemManager::RdmaPoolInstance(rtMemType_t memory_type) { return Instance().GetAllocator(memory_type, rdma_allocator_map_); } +#if (ENABLE_OPEN_SRC != True) +HostMemAllocator &MemManager::HostMemInstance(rtMemType_t memory_type) { + return Instance().GetAllocator(memory_type, host_allocator_map_); +} +#endif } // namespace ge diff --git a/ge/graph/manager/graph_mem_allocator.h b/ge/graph/manager/graph_mem_allocator.h index 2723ae5c..5cca5854 100644 --- a/ge/graph/manager/graph_mem_allocator.h +++ b/ge/graph/manager/graph_mem_allocator.h @@ -139,7 +139,9 @@ class MemoryAllocator { using MemoryAllocatorPtr = std::shared_ptr; class CachingAllocator; class RdmaPoolAllocator; - +#if (ENABLE_OPEN_SRC != True) +class HostMemAllocator; +#endif class MemManager { public: MemManager(); @@ -148,6 +150,9 @@ class MemManager { static MemoryAllocator *Instance(rtMemType_t memory_type); CachingAllocator &CachingInstance(rtMemType_t memory_type); RdmaPoolAllocator &RdmaPoolInstance(rtMemType_t memory_type); +#if (ENABLE_OPEN_SRC != True) + HostMemAllocator &HostMemInstance(rtMemType_t memory_type); +#endif MemManager(const MemManager &) = delete; MemManager &operator=(const MemManager &) = delete; /// @@ -235,6 +240,9 @@ class MemManager { std::map memory_allocator_map_; std::map caching_allocator_map_; std::map rdma_allocator_map_; +#if (ENABLE_OPEN_SRC != True) + std::map host_allocator_map_; +#endif std::recursive_mutex allocator_mutex_; }; } // namespace ge diff --git a/ge/graph/manager/host_mem_allocator.cc b/ge/graph/manager/host_mem_allocator.cc new file mode 100644 index 00000000..d0417e91 --- /dev/null +++ b/ge/graph/manager/host_mem_allocator.cc @@ -0,0 +1,70 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/manager/host_mem_allocator.h" +#include "framework/common/debug/ge_log.h" +#include "common/ge/ge_util.h" + +namespace ge { +const void *HostMemAllocator::Malloc(const std::shared_ptr &aligned_ptr, size_t size) { + if (aligned_ptr == nullptr) { + GELOGW("Insert a null aligned_ptr"); + return nullptr; + } + GELOGD("allocate existed host memory succ, addr=%p, size=%zu", aligned_ptr->Get(), size); + allocated_blocks_[aligned_ptr->Get()] = { size, aligned_ptr }; + return aligned_ptr->Get(); +} + +uint8_t *HostMemAllocator::Malloc(size_t size) { + GELOGD("start to malloc host memory, size=%zu", size); + std::lock_guard lock(mutex_); + std::shared_ptr aligned_ptr = MakeShared(size); + if (aligned_ptr == nullptr) { + GELOGE(INTERNAL_ERROR, "make shared_ptr for AlignedPtr failed"); + return nullptr; + } + allocated_blocks_[aligned_ptr->Get()] = { size, aligned_ptr }; + GELOGD("allocate host memory succ, addr=%p, size=%zu", aligned_ptr->Get(), size); + return aligned_ptr->MutableGet(); +} + +Status HostMemAllocator::Free(const void *memory_addr) { + GELOGD("Free host memory, addr=%p", memory_addr); + if (memory_addr == nullptr) { + GELOGE(GE_GRAPH_FREE_FAILED, "Invalid memory pointer"); + return GE_GRAPH_FREE_FAILED; + } + + std::lock_guard lock(mutex_); + auto it = allocated_blocks_.find(memory_addr); + if (it == allocated_blocks_.end()) { + GELOGE(PARAM_INVALID, "Invalid memory pointer"); + return PARAM_INVALID; + } + it->second.second.reset(); + allocated_blocks_.erase(it); + + return SUCCESS; +} + +void HostMemAllocator::Clear() { + for (auto &block : allocated_blocks_) { + block.second.second.reset(); + } + allocated_blocks_.clear(); +} +} // namespace ge diff --git a/ge/graph/manager/host_mem_allocator.h b/ge/graph/manager/host_mem_allocator.h new file mode 100644 index 00000000..2138da63 --- /dev/null +++ b/ge/graph/manager/host_mem_allocator.h @@ -0,0 +1,58 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_MANAGER_HOST_MEM_ALLOCATOR_H_ +#define GE_GRAPH_MANAGER_HOST_MEM_ALLOCATOR_H_ + +#include +#include + +#include "framework/common/ge_inner_error_codes.h" +#include "graph/aligned_ptr.h" +#include "runtime/mem.h" + +namespace ge { +class HostMemAllocator { + public: + explicit HostMemAllocator(rtMemType_t memory_type) : memory_type_(memory_type) {} + ~HostMemAllocator() = default; + + HostMemAllocator(const HostMemAllocator &) = delete; + HostMemAllocator &operator=(const HostMemAllocator &) = delete; + + Status Initialize() { + Clear(); + return SUCCESS; + } + void Finalize() { Clear(); } + + const void *Malloc(const std::shared_ptr& aligned_ptr, size_t size); + uint8_t *Malloc(size_t size); + Status Free(const void *memory_addr); + + std::pair> GetAlignedPtr(const void *addr) { return allocated_blocks_[addr]; } + + private: + void Clear(); + + rtMemType_t memory_type_; + std::unordered_map>> allocated_blocks_; + // lock around all operations + mutable std::mutex mutex_; +}; +} // namespace ge + +#endif // GE_GRAPH_MANAGER_HOST_MEM_ALLOCATOR_H_ diff --git a/ge/graph/manager/host_mem_manager.cc b/ge/graph/manager/host_mem_manager.cc index c99c9e87..1b530938 100644 --- a/ge/graph/manager/host_mem_manager.cc +++ b/ge/graph/manager/host_mem_manager.cc @@ -43,16 +43,32 @@ Status SharedMemAllocator::Allocate(SharedMemInfo &mem_info) { return GE_GRAPH_MEMORY_ALLOC_FAILED; } mem_info.fd = output_para.fd; +#if (ENABLE_OPEN_SRC != True) + mem_info.host_aligned_ptr = AlignedPtr::BuildAlignedPtr(mem_info.mem_size, + [&output_para](std::unique_ptr &ptr) { + GELOGD("set aligned_ptr, addr=%p", output_para.ptr); + ptr.reset(reinterpret_cast(output_para.ptr)); + }, + [](uint8_t *ptr) { + GELOGD("reset aligned_ptr in SharedMemAllocator, addr=%p", ptr); + ptr = nullptr; + }, 0); +#else mem_info.host_address = reinterpret_cast(output_para.ptr); +#endif mem_info.device_address = reinterpret_cast(output_para.devPtr); return SUCCESS; } Status SharedMemAllocator::DeAllocate(SharedMemInfo &mem_info) { GELOGD("SharedMemAllocator::DeAllocate"); +#if (ENABLE_OPEN_SRC != True) + rtFreeHostSharedMemoryIn free_para = {mem_info.shm_name.c_str(), mem_info.mem_size, mem_info.fd, + mem_info.host_aligned_ptr->MutableGet(), mem_info.device_address}; +#else rtFreeHostSharedMemoryIn free_para = {mem_info.shm_name.c_str(), mem_info.mem_size, mem_info.fd, mem_info.host_address, mem_info.device_address}; - +#endif rtError_t rt_ret = rtFreeHostSharedMemory(&free_para); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api(rtFreeHostSharedMemory) failed, ret: 0x%X.", rt_ret); diff --git a/ge/graph/manager/host_mem_manager.h b/ge/graph/manager/host_mem_manager.h index 66bd5826..72b65b27 100644 --- a/ge/graph/manager/host_mem_manager.h +++ b/ge/graph/manager/host_mem_manager.h @@ -42,7 +42,11 @@ struct SharedMemInfo { uint64_t mem_size = 0; int fd = 0; uint8_t *device_address = nullptr; +#if !ENABLE_OPEN_SRC + std::shared_ptr host_aligned_ptr = nullptr; +#else uint8_t *host_address = nullptr; +#endif SharedMemInfo() = default; SharedMemInfo(string name, uint64_t size) : op_name(std::move(name)), mem_size(size) {} }; diff --git a/ge/graph/passes/assign_pass.cc b/ge/graph/passes/assign_pass.cc index bb7a0f04..7ffc397c 100644 --- a/ge/graph/passes/assign_pass.cc +++ b/ge/graph/passes/assign_pass.cc @@ -15,8 +15,6 @@ */ #include "graph/passes/assign_pass.h" - -#include "framework/common/debug/ge_log.h" #include "framework/common/debug/log.h" #include "graph/utils/graph_utils.h" #include "graph/debug/ge_attr_define.h" @@ -28,6 +26,124 @@ const int32_t kAssignValueInputIndex = 1; } namespace ge { +#if (ENABLE_OPEN_SRC != True) +Status AssignPass::Run(NodePtr &node) { + GELOGD("AssignPass running"); + + if (TransformAttr(node) != SUCCESS) { + GELOGE(FAILED, "Transform assign_var_name attr failed, node=%s", node->GetName().c_str()); + return FAILED; + } + + if (node->GetType() == ASSIGN) { + if (OptimizedAssignNode(node) != SUCCESS) { + GELOGE(FAILED, "Optimize for assign_node %s failed", node->GetName().c_str()); + return FAILED; + } + } + + GELOGD("AssignPass success"); + return SUCCESS; +} + +/// +/// @brief Optimize for assign_node +/// @param [in] assign_node +/// @return Status +/// +Status AssignPass::OptimizedAssignNode(NodePtr &assign_node) { + const auto &ref_in_anchor = assign_node->GetInDataAnchor(kAssignRefInputIndex); + const auto &value_in_anchor = assign_node->GetInDataAnchor(kAssignValueInputIndex); + if ((ref_in_anchor == nullptr) || (value_in_anchor == nullptr)) { + GELOGE(FAILED, "In data anchor is null, node:%s", assign_node->GetName().c_str()); + return FAILED; + } + const auto &ref_peer_anchor = ref_in_anchor->GetPeerOutAnchor(); + const auto &value_peer_anchor = value_in_anchor->GetPeerOutAnchor(); + if ((ref_peer_anchor == nullptr) || (value_peer_anchor == nullptr)) { + GELOGE(FAILED, "Peer data anchor is null, node:%s", assign_node->GetName().c_str()); + return FAILED; + } + + if (IsCondMatch(assign_node, ref_peer_anchor, value_peer_anchor)) { + /// + /// variable not-const not-const + /// \ / | + /// \ / | + /// Assign ----> variable + /// | | + /// | | + /// node node + /// + GELOGD("Optimization for assign_node %s start", assign_node->GetName().c_str()); + if (IsolateAndDeleteNode(assign_node, {kAssignRefInputIndex}) != SUCCESS) { + GELOGE(FAILED, "Isolate and delete assign_node %s failed.", assign_node->GetName().c_str()); + return FAILED; + } + AddNodeDeleted(assign_node); + + const auto &ref_input = ref_peer_anchor->GetOwnerNode()->GetOpDesc(); + const auto &value_input = value_peer_anchor->GetOwnerNode()->GetOpDesc(); + if ((ref_input == nullptr) || (value_input == nullptr)) { + GELOGE(FAILED, "value input is null"); + return FAILED; + } + + // variable has and only has one input + if (ref_input->UpdateInputDesc(0, value_input->GetOutputDesc(value_peer_anchor->GetIdx())) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Update input_desc for variable %s failed.", ref_input->GetName().c_str()); + return FAILED; + } + if (GraphUtils::AddEdge(value_peer_anchor, ref_peer_anchor->GetOwnerNode()->GetInDataAnchor(0)) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Add data edge %s->%s failed", value_input->GetName().c_str(), ref_input->GetName().c_str()); + return FAILED; + } + + GELOGD("add attr ASSIGN_VAR_NAME on node %s, var_name=%s", + value_input->GetName().c_str(), ref_input->GetName().c_str()); + if (!AttrUtils::SetStr(value_input->MutableOutputDesc(value_peer_anchor->GetIdx()), ASSIGN_VAR_NAME, + ref_input->GetName())) { + GELOGE(FAILED, "Set attr ASSIGN_VAR_NAME failed."); + return FAILED; + } + auto value_node = value_peer_anchor->GetOwnerNode(); + AddRePassNode(value_node); + } + return SUCCESS; +} + +/// +/// @brief Transform assign_var_name attr +/// @param [in] node +/// @return Status +/// +Status AssignPass::TransformAttr(NodePtr &node) { + GE_CHECK_NOTNULL(node->GetOpDesc()); + for (const auto &output_desc : node->GetOpDesc()->GetAllOutputsDesc()) { + int32_t inplace_input_idx = -1; + std::string assign_var_name; + if (AttrUtils::GetInt(output_desc, INPLACE_SUPPORT_INPUT_INDEX, inplace_input_idx) && + AttrUtils::GetStr(output_desc, ASSIGN_VAR_NAME, assign_var_name)) { + GELOGD("Transform attr ASSIGN_VAR_NAME on node %s, assign_var_name=%s, inplace_input_idx=%d, ", + node->GetName().c_str(), assign_var_name.c_str(), inplace_input_idx); + const auto &in_data_anchor = node->GetInDataAnchor(inplace_input_idx); + GE_CHECK_NOTNULL(in_data_anchor); + const auto &peer_data_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_data_anchor); + auto in_node = peer_data_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(in_node->GetOpDesc()); + GELOGD("add attr ASSIGN_VAR_NAME on node %s, var_name=%s", in_node->GetName().c_str(), assign_var_name.c_str()); + if (!AttrUtils::SetStr(in_node->GetOpDesc()->MutableOutputDesc(peer_data_anchor->GetIdx()), + ASSIGN_VAR_NAME, assign_var_name)) { + GELOGE(FAILED, "Set attr ASSIGN_VAR_NAME failed."); + return FAILED; + } + AddRePassNode(in_node); + } + } + return SUCCESS; +} +#else Status AssignPass::Run(NodePtr &node) { GELOGD("AssignPass running"); if (node->GetType() != ASSIGN) { @@ -91,7 +207,7 @@ Status AssignPass::Run(NodePtr &node) { GELOGD("AssignPass success"); return SUCCESS; } - +#endif /// /// @brief Check if need optimize for assign_node /// @param [in] assign_node diff --git a/ge/graph/passes/assign_pass.h b/ge/graph/passes/assign_pass.h index 11cf1073..349da52e 100644 --- a/ge/graph/passes/assign_pass.h +++ b/ge/graph/passes/assign_pass.h @@ -25,6 +25,21 @@ class AssignPass : public BaseNodePass { Status Run(NodePtr &node) override; private: +#if (ENABLE_OPEN_SRC != True) + /// + /// @brief Optimize for assign_node + /// @param [in] assign_node + /// @return Status + /// + Status OptimizedAssignNode(NodePtr &assign_node); + + /// + /// @brief Transform assign_var_name attr + /// @param [in] node + /// @return Status + /// + Status TransformAttr(NodePtr &node); +#endif /// /// @brief Check if need optimize for assign_node /// @param [in] assign_node diff --git a/ge/graph/passes/constant_fuse_same_pass.cc b/ge/graph/passes/constant_fuse_same_pass.cc index d0970c59..ec5efcb4 100644 --- a/ge/graph/passes/constant_fuse_same_pass.cc +++ b/ge/graph/passes/constant_fuse_same_pass.cc @@ -19,13 +19,7 @@ #include #include #include -#include #include - -#include "common/ge/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "framework/common/ge_inner_error_codes.h" -#include "graph/debug/ge_attr_define.h" #include "graph/utils/op_desc_utils.h" #include "graph/utils/type_utils.h" @@ -121,11 +115,21 @@ void ConstantFuseSamePass::GetFuseConstNodes(ComputeGraphPtr &graph, TypeUtils::DataTypeToSerialString(data_type).c_str()); continue; } +#if (ENABLE_OPEN_SRC != True) + if ((type_size != 0) && (weight->MutableData().GetAlignedPtr() == nullptr)) { + GELOGW("aligned_ptr is null while size is not 0"); + continue; + } +#endif ++insert_const_nums; SameConstKey map_key; map_key.data_size = type_size; +#if (ENABLE_OPEN_SRC != True) + map_key.aligned_ptr = weight->MutableData().GetAlignedPtr(); +#else map_key.data = weight->GetData().GetData(); +#endif map_key.data_type = data_type; map_key.format = output_tensor->GetFormat(); map_key.shape = output_tensor->GetShape().GetDims(); diff --git a/ge/graph/passes/constant_fuse_same_pass.h b/ge/graph/passes/constant_fuse_same_pass.h index 4935da84..605e10b3 100755 --- a/ge/graph/passes/constant_fuse_same_pass.h +++ b/ge/graph/passes/constant_fuse_same_pass.h @@ -21,14 +21,20 @@ #include #include #include - +#if (ENABLE_OPEN_SRC != True) +#include "graph/aligned_ptr.h" +#endif #include "graph/types.h" #include "inc/graph_pass.h" namespace ge { struct SameConstKey { int data_size; +#if (ENABLE_OPEN_SRC != True) + std::shared_ptr aligned_ptr; +#else const uint8_t *data; +#endif DataType data_type; Format format; std::vector shape; @@ -38,10 +44,19 @@ struct SameConstKey { if (data_size != key.data_size) { return data_size < key.data_size; } +#if (ENABLE_OPEN_SRC != True) + if (data_size != 0) { + int ret = memcmp(aligned_ptr->Get(), key.aligned_ptr->Get(), data_size); + if (ret != 0) { + return ret < 0; + } + } +#else int ret = memcmp(data, key.data, data_size); if (ret != 0) { return ret < 0; } +#endif if (data_type != key.data_type) { return data_type < key.data_type; } diff --git a/ge/graph/passes/inplace_support_check_pass.cc b/ge/graph/passes/inplace_support_check_pass.cc new file mode 100644 index 00000000..06986677 --- /dev/null +++ b/ge/graph/passes/inplace_support_check_pass.cc @@ -0,0 +1,86 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/inplace_support_check_pass.h" +#include "framework/common/debug/log.h" +#include "graph/utils/graph_utils.h" +#include "graph/debug/ge_attr_define.h" + +namespace { +const uint32_t kInplaceSupportOutputIndex = 0; +const uint32_t kInplaceSupportOutputNum = 1; +static const std::set src_node_types = { ge::DATA, ge::ANN_DATA, ge::AIPPDATA, + ge::CONSTANT, ge::CONSTANTOP, + ge::VARIABLE, ge::VARIABLEV2 }; +} + +namespace ge { +Status InplaceSupportCheckPass::Run(NodePtr &node) { + GELOGD("InplaceSupportCheckPass running"); + if (src_node_types.count(node->GetType()) > 0) { + GELOGD("meet src_node %s, skip InplaceSupportCheckPass", node->GetName().c_str()); + return SUCCESS; + } + if (node->GetAllOutDataAnchorsSize() != kInplaceSupportOutputNum) { + GELOGD("output num of node %s is not %u, skip InplaceSupportCheckPass", + node->GetName().c_str(), kInplaceSupportOutputNum); + return SUCCESS; + } + GE_CHECK_NOTNULL(node->GetOpDesc()); + const DataType &output_type = node->GetOpDesc()->GetOutputDesc(kInplaceSupportOutputIndex).GetDataType(); + const GeShape &output_shape = node->GetOpDesc()->GetOutputDesc(kInplaceSupportOutputIndex).GetShape(); + GELOGD("process InplaceSupportCheckPass on node %s", node->GetName().c_str()); + for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { + const auto &peer_data_anchor = in_data_anchor->GetPeerOutAnchor(); + if (peer_data_anchor == nullptr) { + continue; + } + auto in_node = peer_data_anchor->GetOwnerNode(); + if (src_node_types.count(in_node->GetType()) > 0) { + GELOGD("meet src_node %s", in_node->GetName().c_str()); + continue; + } + if (peer_data_anchor->GetPeerInDataNodesSize() != kInplaceSupportOutputNum) { + GELOGD("peer_data_anchor links with multi in_data_anchors"); + continue; + } + + int32_t inplace_input_idx = in_data_anchor->GetIdx(); + const DataType &input_type = node->GetOpDesc()->GetInputDesc(inplace_input_idx).GetDataType(); + const GeShape &input_shape = node->GetOpDesc()->GetInputDesc(inplace_input_idx).GetShape(); + if (input_type != output_type) { + GELOGD("DataType mismatch, in_idx=%d, input_type=%u, output_type=%u", inplace_input_idx, input_type, output_type); + continue; + } + if (input_shape.GetDims() != output_shape.GetDims()) { + GELOGD("Shape mismatch, in_idx=%d, input_shape=[%s], output_shape=[%s]", + inplace_input_idx, input_shape.ToString().c_str(), output_shape.ToString().c_str()); + continue; + } + + GELOGD("add attr INPLACE_SUPPORT_INPUT_INDEX on node %s, input_idx=%d", node->GetName().c_str(), inplace_input_idx); + if (!AttrUtils::SetInt(node->GetOpDesc()->MutableOutputDesc(kInplaceSupportOutputIndex), + INPLACE_SUPPORT_INPUT_INDEX, inplace_input_idx)) { + GELOGE(FAILED, "Set attr INPLACE_SUPPORT_INPUT_INDEX on node %s failed.", node->GetName().c_str()); + return FAILED; + } + AddRePassNode(node); + } + + GELOGD("InplaceSupportCheckPass success"); + return SUCCESS; +} +} // namespace ge diff --git a/ge/graph/passes/inplace_support_check_pass.h b/ge/graph/passes/inplace_support_check_pass.h new file mode 100644 index 00000000..be2d6c75 --- /dev/null +++ b/ge/graph/passes/inplace_support_check_pass.h @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_INPLACE_SUPPORT_CHECK_PASS_H_ +#define GE_GRAPH_PASSES_INPLACE_SUPPORT_CHECK_PASS_H_ + +#include "graph/passes/base_pass.h" + +namespace ge { +class InplaceSupportCheckPass : public BaseNodePass { + public: + Status Run(NodePtr &node) override; +}; +} // namespace ge +#endif // GE_GRAPH_PASSES_INPLACE_SUPPORT_CHECK_PASS_H_ diff --git a/ge/graph/passes/switch_to_stream_switch_pass.cc b/ge/graph/passes/switch_to_stream_switch_pass.cc index a7b922e0..392968e7 100644 --- a/ge/graph/passes/switch_to_stream_switch_pass.cc +++ b/ge/graph/passes/switch_to_stream_switch_pass.cc @@ -598,7 +598,7 @@ Status SwitchToStreamSwitchPass::AddConstNode(const ComputeGraphPtr &graph, cons /// Status SwitchToStreamSwitchPass::ModifySwitchInCtlEdges(const NodePtr &switch_node, const NodePtr &cast_node, const std::set &same_cond_switch) { - GELOGD("ModifySwitchInCtlEdges: switch_node=%s, active_node=%s", switch_node->GetName().c_str(), + GELOGD("ModifySwitchInCtlEdges: switch_node=%s, cast_node=%s", switch_node->GetName().c_str(), cast_node->GetName().c_str()); std::string orig_switch_name = switch_node->GetName(); OpDescPtr switch_desc = switch_node->GetOpDesc(); diff --git a/ge/graph/preprocess/graph_preprocess.cc b/ge/graph/preprocess/graph_preprocess.cc index da862836..bea6aa6e 100644 --- a/ge/graph/preprocess/graph_preprocess.cc +++ b/ge/graph/preprocess/graph_preprocess.cc @@ -19,7 +19,6 @@ #include #include #include "common/formats/format_transfers/format_transfer_fractal_nz.h" -#include "common/formats/format_transfers/format_transfer_fractal_z.h" #include "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h" #include "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h" #include "common/formats/format_transfers/format_transfer_transpose.h" @@ -38,7 +37,6 @@ #include "graph/passes/addn_pass.h" #include "graph/passes/aicpu_constant_folding_pass.h" #include "graph/passes/assert_pass.h" -#include "graph/passes/assign_pass.h" #include "graph/passes/common_subexpression_elimination_pass.h" #include "graph/passes/cond_pass.h" #include "graph/passes/cond_remove_pass.h" @@ -1699,7 +1697,9 @@ Status GraphPrepare::PrepareOptimize() { VarIsInitializedOpPass var_is_initialized_pass; ParallelConcatStartOpPass parallel_concat_start_op_pass; IdentityPass identity_pass(false); +#if (ENABLE_OPEN_SRC == True) AssignPass assign_pass; +#endif SnapshotPass snapshot_pass; if (!options_.train_graph_flag) { names_to_passes.emplace_back("DropOutPass", &dropout_pass); @@ -1714,9 +1714,11 @@ Status GraphPrepare::PrepareOptimize() { names_to_passes.emplace_back("VarIsInitializedOpPass", &var_is_initialized_pass); names_to_passes.emplace_back("ParallelConcatStartOpPass", ¶llel_concat_start_op_pass); names_to_passes.emplace_back("IdentityPass", &identity_pass); +#if (ENABLE_OPEN_SRC == True) if (GetContext().GetHostExecFlag()) { names_to_passes.emplace_back("AssignPass", &assign_pass); } +#endif GE_TIMESTAMP_START(names_to_passes); ret = ge_passes.Run(names_to_passes); GE_TIMESTAMP_END(names_to_passes, "GraphPrepare::NamesToPasses"); diff --git a/ge/hybrid/common/npu_memory_allocator.cc b/ge/hybrid/common/npu_memory_allocator.cc index 2c38367a..7ed6a882 100644 --- a/ge/hybrid/common/npu_memory_allocator.cc +++ b/ge/hybrid/common/npu_memory_allocator.cc @@ -20,6 +20,9 @@ #include "graph/manager/graph_caching_allocator.h" #include "graph/manager/graph_mem_allocator.h" #include "graph/manager/rdma_pool_allocator.h" +#if (ENABLE_OPEN_SRC != True) +#include "graph/manager/host_mem_allocator.h" +#endif namespace ge { namespace hybrid { @@ -64,7 +67,11 @@ void *NpuMemoryAllocator::Allocate(std::size_t size, AllocationAttr *attr) { if (mem_type == RDMA_HBM) { buffer = MemManager::Instance().RdmaPoolInstance(RT_MEMORY_HBM).Malloc(allocate_size, device_id_); } else if (mem_type == HOST_DDR) { +#if (ENABLE_OPEN_SRC != True) + buffer = MemManager::Instance().HostMemInstance(RT_MEMORY_HBM).Malloc(allocate_size); +#else buffer = malloc(allocate_size); +#endif } else { if (allocate_size > kMaxHbmMemorySize) { GELOGE(PARAM_INVALID, "Invalid HBM memory size: %zu", allocate_size); @@ -101,7 +108,11 @@ void NpuMemoryAllocator::Deallocate(void *data, MemStorageType mem_type) { if (mem_type == RDMA_HBM) { MemManager::Instance().RdmaPoolInstance(RT_MEMORY_HBM).Free(reinterpret_cast(data), device_id_); } else if (mem_type == HOST_DDR) { +#if (ENABLE_OPEN_SRC != True) + MemManager::Instance().HostMemInstance(RT_MEMORY_HBM).Free(data); +#else free(data); +#endif } else { MemManager::Instance().CachingInstance(RT_MEMORY_HBM).Free(reinterpret_cast(data), device_id_); } diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index cba83dbd..195edfc4 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -24,11 +24,13 @@ #include "graph/manager/graph_var_manager.h" #include "graph/manager/host_mem_manager.h" #include "graph/manager/trans_var_data_utils.h" +#if (ENABLE_OPEN_SRC != True) +#include "graph/manager/graph_mem_allocator.h" +#include "graph/manager/host_mem_allocator.h" +#endif #include "graph/utils/graph_utils.h" #include "hybrid/common/npu_memory_allocator.h" #include "hybrid/node_executor/node_executor.h" -#include "framework/common/debug/ge_log.h" -#include "graph/utils/attr_utils.h" namespace ge { namespace hybrid { @@ -851,9 +853,24 @@ Status HybridModelBuilder::InitConstantOps() { std::unique_ptr var_tensor; if (GetContext().GetHostExecFlag()) { +#if (ENABLE_OPEN_SRC != True) + GE_CHECK_NOTNULL(ge_tensor); + // Address for eigen kernel should be aligned with 16 bytes + // Tensors return by api GetWeights share data with proto, whose addr is not confirmed to be aligned + GeTensor aligned_tensor = ge_tensor->Clone(); + GELOGD("Init tensor with host constant %s size = %zu", var_name.c_str(), aligned_tensor.MutableData().GetSize()); + if (MemManager::Instance().HostMemInstance(RT_MEMORY_HBM).Malloc(aligned_tensor.GetAlignedPtr(), + aligned_tensor.GetData().size()) == nullptr) { + GELOGE(MEMALLOC_FAILED, "Malloc host memory for an existed GeTensor failed."); + return MEMALLOC_FAILED; + } + var_tensor.reset(new(std::nothrow)TensorValue(aligned_tensor.MutableData().data(), + aligned_tensor.GetData().size())); +#else auto buffer = ge_tensor->MutableData(); GELOGD("Init tensor with host constant. size = %zu", buffer.GetSize()); var_tensor.reset(new(std::nothrow)TensorValue(buffer.GetData(), buffer.GetSize())); +#endif } else { GE_CHK_STATUS_RET_NOLOG(VarNodeToTensor(var_node, var_tensor)); GELOGD("Init const op tensor. name = %s, size = %ld", var_name.c_str(), var_tensor->GetSize()); @@ -908,9 +925,22 @@ Status HybridModelBuilder::InitVariableTensors() { GELOGE(GE_GRAPH_MALLOC_FAILED, "Host variable [%s] malloc failed.", it.first.c_str()); return GE_GRAPH_MALLOC_FAILED; } +#if (ENABLE_OPEN_SRC != True) + if (MemManager::Instance().HostMemInstance(RT_MEMORY_HBM).Malloc(mem_info.host_aligned_ptr, + tensor_size) == nullptr) { + GELOGE(MEMALLOC_FAILED, "Malloc host memory for an existed GeTensor failed."); + return MEMALLOC_FAILED; + } + GELOGD("Host variable [%s] malloc success, host_addr=%p, dev_addr=%p, size=%lld.", + it.first.c_str(), mem_info.host_aligned_ptr->Get(), mem_info.device_address, tensor_size); + + std::unique_ptr tensor(new (std::nothrow) TensorValue(mem_info.host_aligned_ptr->MutableGet(), + tensor_size)); +#else GELOGD("Host variable [%s] malloc success.", it.first.c_str()); std::unique_ptr tensor(new (std::nothrow) TensorValue(mem_info.host_address, tensor_size)); +#endif GE_CHECK_NOTNULL(tensor); hybrid_model_.variable_tensors_.emplace(it.first, std::move(tensor)); } diff --git a/ge/hybrid/node_executor/host_cpu/host_cpu_node_executor.cc b/ge/hybrid/node_executor/host_cpu/host_cpu_node_executor.cc index a61195b0..9a76c24a 100755 --- a/ge/hybrid/node_executor/host_cpu/host_cpu_node_executor.cc +++ b/ge/hybrid/node_executor/host_cpu/host_cpu_node_executor.cc @@ -18,6 +18,10 @@ #include "hybrid/node_executor/host_cpu/kernel_factory.h" #include "graph/passes/folding_pass.h" #include "hybrid/model/hybrid_model.h" +#if (ENABLE_OPEN_SRC != True) +#include "graph/manager/graph_mem_allocator.h" +#include "graph/manager/host_mem_allocator.h" +#endif #include "ge_local_engine/engine/host_cpu_engine.h" namespace ge { @@ -50,15 +54,23 @@ Status CpuKernelNodeTask::Execute(TaskContext &context) { auto input_desc_ptr = context.GetInputDesc(i); GE_CHECK_NOTNULL(input_desc_ptr); const auto &input_desc = *input_desc_ptr; +#if (ENABLE_OPEN_SRC != True) + auto tensor = context.GetInput(i); + GE_CHECK_NOTNULL(tensor); + auto item = MemManager::Instance().HostMemInstance(RT_MEMORY_HBM).GetAlignedPtr(tensor->GetData()); + GE_CHECK_NOTNULL(item.second); + auto in_tensor = MakeShared(input_desc, item.second, item.first); +#else GE_CHECK_NOTNULL(context.GetInput(i)); auto in_tensor = MakeShared(input_desc, reinterpret_cast(context.GetInput(i)->GetData()), context.GetInput(i)->GetSize()); +#endif GE_CHECK_NOTNULL(in_tensor); in_tensor->MutableTensorDesc().SetDataType(input_desc.GetDataType()); in_tensor->MutableTensorDesc().SetShape(input_desc.GetShape()); inputs.emplace_back(in_tensor); - GELOGI("node:%s allocate input %d, size=%zu", op_desc->GetName().c_str(), i, in_tensor->GetData().size()); + GELOGD("node:%s allocate input %d, size=%zu", op_desc->GetName().c_str(), i, in_tensor->GetData().size()); } std::vector outputs; @@ -72,14 +84,20 @@ Status CpuKernelNodeTask::Execute(TaskContext &context) { } auto tensor = context.GetOutput(i); GE_CHECK_NOTNULL(tensor); +#if (ENABLE_OPEN_SRC != True) + auto item = MemManager::Instance().HostMemInstance(RT_MEMORY_HBM).GetAlignedPtr(tensor->GetData()); + GE_CHECK_NOTNULL(item.second); + auto out_tensor = MakeShared(output_desc, item.second, item.first); +#else auto out_tensor = MakeShared(output_desc, reinterpret_cast(tensor->GetData()), tensor->GetSize()); +#endif GE_CHECK_NOTNULL(out_tensor); out_tensor->MutableTensorDesc().SetDataType(output_desc.GetDataType()); out_tensor->MutableTensorDesc().SetShape(output_desc.GetShape()); outputs.emplace_back(out_tensor); - GELOGI("node:%s allocate output %d, size=%zu", op_desc->GetName().c_str(), i, out_tensor->GetData().size()); + GELOGD("node:%s allocate output %d, size=%zu", op_desc->GetName().c_str(), i, out_tensor->GetData().size()); } return HostCpuEngine::GetInstance().Run(node_, inputs, outputs);