@@ -124,7 +124,7 @@ set(TRAIN_SRC_LIST | |||||
"graph/manager/graph_var_manager.cc" | "graph/manager/graph_var_manager.cc" | ||||
"graph/manager/host_mem_manager.cc" | "graph/manager/host_mem_manager.cc" | ||||
"graph/manager/rdma_pool_allocator.cc" | "graph/manager/rdma_pool_allocator.cc" | ||||
$<$<NOT:$<STREQUAL:${ENABLE_OPEN_SRC},True>>:graph/manager/host_mem_allocator.cc> | |||||
"graph/manager/host_mem_allocator.cc" | |||||
"graph/manager/memory_api.cc" | "graph/manager/memory_api.cc" | ||||
"graph/manager/model_manager/event_manager.cc" | "graph/manager/model_manager/event_manager.cc" | ||||
"graph/manager/trans_var_data_utils.cc" | "graph/manager/trans_var_data_utils.cc" | ||||
@@ -166,7 +166,7 @@ set(TRAIN_SRC_LIST | |||||
"graph/passes/hccl_group_pass.cc" | "graph/passes/hccl_group_pass.cc" | ||||
"graph/passes/enter_pass.cc" | "graph/passes/enter_pass.cc" | ||||
"graph/passes/assign_remove_pass.cc" | "graph/passes/assign_remove_pass.cc" | ||||
$<$<NOT:$<STREQUAL:${ENABLE_OPEN_SRC},True>>:graph/passes/inplace_support_check_pass.cc> | |||||
"graph/passes/inplace_support_check_pass.cc" | |||||
"graph/passes/flow_ctrl_pass.cc" | "graph/passes/flow_ctrl_pass.cc" | ||||
"graph/passes/global_step_insert_pass.cc" | "graph/passes/global_step_insert_pass.cc" | ||||
"host_kernels/transpose_kernel.cc" | "host_kernels/transpose_kernel.cc" | ||||
@@ -409,7 +409,7 @@ set(INFER_SRC_LIST | |||||
"graph/manager/graph_var_manager.cc" | "graph/manager/graph_var_manager.cc" | ||||
"graph/manager/host_mem_manager.cc" | "graph/manager/host_mem_manager.cc" | ||||
"graph/manager/rdma_pool_allocator.cc" | "graph/manager/rdma_pool_allocator.cc" | ||||
$<$<NOT:$<STREQUAL:${ENABLE_OPEN_SRC},True>>:graph/manager/host_mem_allocator.cc> | |||||
"graph/manager/host_mem_allocator.cc" | |||||
"graph/manager/graph_mem_allocator.cc" | "graph/manager/graph_mem_allocator.cc" | ||||
"graph/manager/graph_caching_allocator.cc" | "graph/manager/graph_caching_allocator.cc" | ||||
"model/ge_model.cc" | "model/ge_model.cc" | ||||
@@ -531,7 +531,7 @@ set(INFER_SRC_LIST | |||||
"graph/passes/for_pass.cc" | "graph/passes/for_pass.cc" | ||||
"graph/passes/enter_pass.cc" | "graph/passes/enter_pass.cc" | ||||
"graph/passes/assign_remove_pass.cc" | "graph/passes/assign_remove_pass.cc" | ||||
$<$<NOT:$<STREQUAL:${ENABLE_OPEN_SRC},True>>:graph/passes/inplace_support_check_pass.cc> | |||||
"graph/passes/inplace_support_check_pass.cc" | |||||
"graph/passes/addn_pass.cc" | "graph/passes/addn_pass.cc" | ||||
"graph/passes/common_subexpression_elimination_pass.cc" | "graph/passes/common_subexpression_elimination_pass.cc" | ||||
"graph/passes/remove_same_const_pass.cc" | "graph/passes/remove_same_const_pass.cc" | ||||
@@ -28,7 +28,7 @@ set(SRC_LIST | |||||
"../graph/manager/trans_var_data_utils.cc" | "../graph/manager/trans_var_data_utils.cc" | ||||
"../graph/manager/util/debug.cc" | "../graph/manager/util/debug.cc" | ||||
"../graph/manager/rdma_pool_allocator.cc" | "../graph/manager/rdma_pool_allocator.cc" | ||||
$<$<NOT:$<STREQUAL:${ENABLE_OPEN_SRC},True>>:../graph/manager/host_mem_allocator.cc> | |||||
"../graph/manager/host_mem_allocator.cc" | |||||
"../hybrid/node_executor/aicpu/aicpu_ext_info.cc" | "../hybrid/node_executor/aicpu/aicpu_ext_info.cc" | ||||
"../model/ge_model.cc" | "../model/ge_model.cc" | ||||
"../model/ge_root_model.cc" | "../model/ge_root_model.cc" | ||||
@@ -26,7 +26,6 @@ | |||||
#include "common/math/math_util.h" | #include "common/math/math_util.h" | ||||
namespace { | namespace { | ||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
#define CREATE_OUTPUT_CASE(DTYPE, TYPE) \ | #define CREATE_OUTPUT_CASE(DTYPE, TYPE) \ | ||||
case (DTYPE): { \ | case (DTYPE): { \ | ||||
GeTensorPtr ge_tensor = nullptr; \ | GeTensorPtr ge_tensor = nullptr; \ | ||||
@@ -50,43 +49,6 @@ namespace { | |||||
named_outputs.emplace(tensor_name, tensor); \ | named_outputs.emplace(tensor_name, tensor); \ | ||||
break; \ | break; \ | ||||
} | } | ||||
#else | |||||
#define CREATE_OUTPUT_CASE(DTYPE, TYPE) \ | |||||
case (DTYPE): { \ | |||||
GeTensorPtr ge_tensor = nullptr; \ | |||||
if (need_create_flag) { \ | |||||
GELOGI("node:%s allocate output %zu start, size=%lld", op_desc->GetName().c_str(), i, data_num * sizeof(TYPE)); \ | |||||
std::unique_ptr<TYPE[]> buf(new (std::nothrow) TYPE[data_num]()); \ | |||||
if (buf == nullptr) { \ | |||||
GELOGE(MEMALLOC_FAILED, "New sizeof(T) * data_num(%zu) memory failed", \ | |||||
static_cast<size_t>(sizeof(TYPE) * data_num)); \ | |||||
return MEMALLOC_FAILED; \ | |||||
} \ | |||||
ge_tensor = MakeShared<GeTensor>(out_desc); \ | |||||
GE_CHECK_NOTNULL(ge_tensor); \ | |||||
GELOGD("node:%s allocate output %zu success, size=%lld", op_desc->GetName().c_str(), i, data_num * sizeof(TYPE));\ | |||||
if (ge_tensor->SetData(reinterpret_cast<uint8_t *>(buf.get()), data_num * sizeof(TYPE)) != GRAPH_SUCCESS) { \ | |||||
GELOGE(MEMALLOC_FAILED, "Set data for output %zu of node %s failed.", i, op_desc->GetName().c_str()); \ | |||||
return MEMALLOC_FAILED; \ | |||||
} \ | |||||
ge_tensor->MutableTensorDesc().SetDataType(out_desc.GetDataType()); \ | |||||
ge_tensor->MutableTensorDesc().SetShape(out_desc.GetShape()); \ | |||||
outputs.emplace_back(ge_tensor); \ | |||||
} else { \ | |||||
ge_tensor = outputs[i]; \ | |||||
GE_CHECK_NOTNULL(ge_tensor); \ | |||||
GELOGD("node:%s existed output %zu", op_desc->GetName().c_str(), i); \ | |||||
} \ | |||||
auto tensor = TensorAdapter::AsTensor(*ge_tensor); \ | |||||
auto tensor_name = op_desc->GetOutputNameByIndex(i); \ | |||||
GE_RETURN_WITH_LOG_IF_TRUE(tensor_name.empty(), "Failed to get output name. node = %s, index = %zu", \ | |||||
op_desc->GetName().c_str(), i); \ | |||||
GELOGD("Successfully inserted output tensor. node = %s, index = %zu, output name = %s, addr = %p, size = %zu", \ | |||||
op_desc->GetName().c_str(), i, tensor_name.c_str(), tensor.GetData(), tensor.GetSize()); \ | |||||
named_outputs.emplace(tensor_name, tensor); \ | |||||
break; \ | |||||
} | |||||
#endif | |||||
} | } | ||||
namespace ge { | namespace ge { | ||||
@@ -38,10 +38,8 @@ | |||||
#include "graph/partition/stage_partition.h" | #include "graph/partition/stage_partition.h" | ||||
#include "graph/passes/addn_pass.h" | #include "graph/passes/addn_pass.h" | ||||
#include "graph/passes/bitcast_pass.h" | #include "graph/passes/bitcast_pass.h" | ||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
#include "graph/passes/assign_remove_pass.h" | #include "graph/passes/assign_remove_pass.h" | ||||
#include "graph/passes/inplace_support_check_pass.h" | #include "graph/passes/inplace_support_check_pass.h" | ||||
#endif | |||||
#include "graph/passes/atomic_addr_clean_pass.h" | #include "graph/passes/atomic_addr_clean_pass.h" | ||||
#include "graph/passes/attach_stream_label_pass.h" | #include "graph/passes/attach_stream_label_pass.h" | ||||
#include "graph/passes/cast_remove_pass.h" | #include "graph/passes/cast_remove_pass.h" | ||||
@@ -2269,20 +2267,16 @@ Status GraphManager::OptimizeStage2(ge::ComputeGraphPtr &compute_graph) { | |||||
ReshapeRemovePass reshape_remove_pass; | ReshapeRemovePass reshape_remove_pass; | ||||
CondRemovePass condition_remove_pass; | CondRemovePass condition_remove_pass; | ||||
BitcastPass bitcast_pass; | BitcastPass bitcast_pass; | ||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
AssignRemovePass assign_remove_pass; | AssignRemovePass assign_remove_pass; | ||||
InplaceSupportCheckPass inplace_support_check_pass; | InplaceSupportCheckPass inplace_support_check_pass; | ||||
#endif | |||||
names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass); | names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass); | ||||
names_to_passes.emplace_back("ReshapeRemovePass", &reshape_remove_pass); | names_to_passes.emplace_back("ReshapeRemovePass", &reshape_remove_pass); | ||||
names_to_passes.emplace_back("CondRemovePass", &condition_remove_pass); | names_to_passes.emplace_back("CondRemovePass", &condition_remove_pass); | ||||
names_to_passes.emplace_back("BitcastPass", &bitcast_pass); | names_to_passes.emplace_back("BitcastPass", &bitcast_pass); | ||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
if (GetContext().GetHostExecFlag()) { | if (GetContext().GetHostExecFlag()) { | ||||
names_to_passes.emplace_back("AssignRemovePass", &assign_remove_pass); | names_to_passes.emplace_back("AssignRemovePass", &assign_remove_pass); | ||||
names_to_passes.emplace_back("InplaceSupportCheckPass", &inplace_support_check_pass); | names_to_passes.emplace_back("InplaceSupportCheckPass", &inplace_support_check_pass); | ||||
} | } | ||||
#endif | |||||
GE_TIMESTAMP_START(names_to_passes); | GE_TIMESTAMP_START(names_to_passes); | ||||
ret = GEPass(compute_graph).Run(names_to_passes); | ret = GEPass(compute_graph).Run(names_to_passes); | ||||
GE_TIMESTAMP_END(names_to_passes, "OptimizeStage2::MergedGraphNameToPasses"); | GE_TIMESTAMP_END(names_to_passes, "OptimizeStage2::MergedGraphNameToPasses"); | ||||
@@ -19,9 +19,7 @@ | |||||
#include <string> | #include <string> | ||||
#include "graph/manager/graph_caching_allocator.h" | #include "graph/manager/graph_caching_allocator.h" | ||||
#include "graph/manager/rdma_pool_allocator.h" | #include "graph/manager/rdma_pool_allocator.h" | ||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
#include "graph/manager/host_mem_allocator.h" | #include "graph/manager/host_mem_allocator.h" | ||||
#endif | |||||
namespace ge { | namespace ge { | ||||
void MemoryAllocator::Initialize(uint32_t device_id) { | void MemoryAllocator::Initialize(uint32_t device_id) { | ||||
GELOGI("MemoryAllocator::Initialize"); | GELOGI("MemoryAllocator::Initialize"); | ||||
@@ -192,12 +190,10 @@ Status MemManager::Initialize(const std::vector<rtMemType_t> &memory_type) { | |||||
GELOGE(ge::INTERNAL_ERROR, "Create RdmaAllocator failed."); | GELOGE(ge::INTERNAL_ERROR, "Create RdmaAllocator failed."); | ||||
return ge::INTERNAL_ERROR; | return ge::INTERNAL_ERROR; | ||||
} | } | ||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
if (InitAllocator(memory_type, host_allocator_map_) != SUCCESS) { | if (InitAllocator(memory_type, host_allocator_map_) != SUCCESS) { | ||||
GELOGE(ge::INTERNAL_ERROR, "Create HostMemAllocator failed."); | GELOGE(ge::INTERNAL_ERROR, "Create HostMemAllocator failed."); | ||||
return ge::INTERNAL_ERROR; | return ge::INTERNAL_ERROR; | ||||
} | } | ||||
#endif | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -219,9 +215,7 @@ void MemManager::Finalize() noexcept { | |||||
// caching and rdma allocator use memory allocator, so finalize them first | // caching and rdma allocator use memory allocator, so finalize them first | ||||
FinalizeAllocatorMap(caching_allocator_map_); | FinalizeAllocatorMap(caching_allocator_map_); | ||||
FinalizeAllocatorMap(rdma_allocator_map_); | FinalizeAllocatorMap(rdma_allocator_map_); | ||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
FinalizeAllocatorMap(host_allocator_map_); | FinalizeAllocatorMap(host_allocator_map_); | ||||
#endif | |||||
FinalizeAllocatorMap(memory_allocator_map_); | FinalizeAllocatorMap(memory_allocator_map_); | ||||
} | } | ||||
@@ -250,9 +244,7 @@ CachingAllocator &MemManager::CachingInstance(rtMemType_t memory_type) { | |||||
RdmaPoolAllocator &MemManager::RdmaPoolInstance(rtMemType_t memory_type) { | RdmaPoolAllocator &MemManager::RdmaPoolInstance(rtMemType_t memory_type) { | ||||
return Instance().GetAllocator(memory_type, rdma_allocator_map_); | return Instance().GetAllocator(memory_type, rdma_allocator_map_); | ||||
} | } | ||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
HostMemAllocator &MemManager::HostMemInstance(rtMemType_t memory_type) { | HostMemAllocator &MemManager::HostMemInstance(rtMemType_t memory_type) { | ||||
return Instance().GetAllocator(memory_type, host_allocator_map_); | return Instance().GetAllocator(memory_type, host_allocator_map_); | ||||
} | } | ||||
#endif | |||||
} // namespace ge | } // namespace ge |
@@ -139,9 +139,7 @@ class MemoryAllocator { | |||||
using MemoryAllocatorPtr = std::shared_ptr<MemoryAllocator>; | using MemoryAllocatorPtr = std::shared_ptr<MemoryAllocator>; | ||||
class CachingAllocator; | class CachingAllocator; | ||||
class RdmaPoolAllocator; | class RdmaPoolAllocator; | ||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
class HostMemAllocator; | class HostMemAllocator; | ||||
#endif | |||||
class MemManager { | class MemManager { | ||||
public: | public: | ||||
MemManager(); | MemManager(); | ||||
@@ -150,9 +148,7 @@ class MemManager { | |||||
static MemoryAllocator *Instance(rtMemType_t memory_type); | static MemoryAllocator *Instance(rtMemType_t memory_type); | ||||
CachingAllocator &CachingInstance(rtMemType_t memory_type); | CachingAllocator &CachingInstance(rtMemType_t memory_type); | ||||
RdmaPoolAllocator &RdmaPoolInstance(rtMemType_t memory_type); | RdmaPoolAllocator &RdmaPoolInstance(rtMemType_t memory_type); | ||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
HostMemAllocator &HostMemInstance(rtMemType_t memory_type); | HostMemAllocator &HostMemInstance(rtMemType_t memory_type); | ||||
#endif | |||||
MemManager(const MemManager &) = delete; | MemManager(const MemManager &) = delete; | ||||
MemManager &operator=(const MemManager &) = delete; | MemManager &operator=(const MemManager &) = delete; | ||||
/// | /// | ||||
@@ -240,9 +236,7 @@ class MemManager { | |||||
std::map<rtMemType_t, MemoryAllocator *> memory_allocator_map_; | std::map<rtMemType_t, MemoryAllocator *> memory_allocator_map_; | ||||
std::map<rtMemType_t, CachingAllocator *> caching_allocator_map_; | std::map<rtMemType_t, CachingAllocator *> caching_allocator_map_; | ||||
std::map<rtMemType_t, RdmaPoolAllocator *> rdma_allocator_map_; | std::map<rtMemType_t, RdmaPoolAllocator *> rdma_allocator_map_; | ||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
std::map<rtMemType_t, HostMemAllocator *> host_allocator_map_; | std::map<rtMemType_t, HostMemAllocator *> host_allocator_map_; | ||||
#endif | |||||
std::recursive_mutex allocator_mutex_; | std::recursive_mutex allocator_mutex_; | ||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
@@ -27,7 +27,7 @@ | |||||
namespace ge { | namespace ge { | ||||
class HostMemAllocator { | class HostMemAllocator { | ||||
public: | public: | ||||
explicit HostMemAllocator(rtMemType_t) {} | |||||
explicit HostMemAllocator(rtMemType_t) {} | |||||
~HostMemAllocator() = default; | ~HostMemAllocator() = default; | ||||
HostMemAllocator(const HostMemAllocator &) = delete; | HostMemAllocator(const HostMemAllocator &) = delete; | ||||
@@ -43,29 +43,20 @@ Status SharedMemAllocator::Allocate(SharedMemInfo &mem_info) { | |||||
return GE_GRAPH_MEMORY_ALLOC_FAILED; | return GE_GRAPH_MEMORY_ALLOC_FAILED; | ||||
} | } | ||||
mem_info.fd = output_para.fd; | mem_info.fd = output_para.fd; | ||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
mem_info.host_aligned_ptr = AlignedPtr::BuildFromAllocFunc([&output_para](std::unique_ptr<uint8_t[], deleter> &ptr) { | mem_info.host_aligned_ptr = AlignedPtr::BuildFromAllocFunc([&output_para](std::unique_ptr<uint8_t[], deleter> &ptr) { | ||||
ptr.reset(reinterpret_cast<uint8_t *>(output_para.ptr)); | ptr.reset(reinterpret_cast<uint8_t *>(output_para.ptr)); | ||||
}, | }, | ||||
[](uint8_t *ptr) { | [](uint8_t *ptr) { | ||||
ptr = nullptr; | ptr = nullptr; | ||||
}); | }); | ||||
#else | |||||
mem_info.host_address = reinterpret_cast<uint8_t *>(output_para.ptr); | |||||
#endif | |||||
mem_info.device_address = reinterpret_cast<uint8_t *>(output_para.devPtr); | mem_info.device_address = reinterpret_cast<uint8_t *>(output_para.devPtr); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status SharedMemAllocator::DeAllocate(SharedMemInfo &mem_info) { | Status SharedMemAllocator::DeAllocate(SharedMemInfo &mem_info) { | ||||
GELOGD("SharedMemAllocator::DeAllocate"); | GELOGD("SharedMemAllocator::DeAllocate"); | ||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
rtFreeHostSharedMemoryIn free_para = {mem_info.shm_name.c_str(), mem_info.mem_size, mem_info.fd, | rtFreeHostSharedMemoryIn free_para = {mem_info.shm_name.c_str(), mem_info.mem_size, mem_info.fd, | ||||
mem_info.host_aligned_ptr->MutableGet(), mem_info.device_address}; | mem_info.host_aligned_ptr->MutableGet(), mem_info.device_address}; | ||||
#else | |||||
rtFreeHostSharedMemoryIn free_para = {mem_info.shm_name.c_str(), mem_info.mem_size, mem_info.fd, | |||||
mem_info.host_address, mem_info.device_address}; | |||||
#endif | |||||
rtError_t rt_ret = rtFreeHostSharedMemory(&free_para); | rtError_t rt_ret = rtFreeHostSharedMemory(&free_para); | ||||
if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
GELOGE(RT_FAILED, "Call rt api(rtFreeHostSharedMemory) failed, ret: 0x%X.", rt_ret); | GELOGE(RT_FAILED, "Call rt api(rtFreeHostSharedMemory) failed, ret: 0x%X.", rt_ret); | ||||
@@ -42,11 +42,7 @@ struct SharedMemInfo { | |||||
uint64_t mem_size = 0; | uint64_t mem_size = 0; | ||||
int fd = 0; | int fd = 0; | ||||
uint8_t *device_address = nullptr; | uint8_t *device_address = nullptr; | ||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
std::shared_ptr<AlignedPtr> host_aligned_ptr = nullptr; | std::shared_ptr<AlignedPtr> host_aligned_ptr = nullptr; | ||||
#else | |||||
uint8_t *host_address = nullptr; | |||||
#endif | |||||
SharedMemInfo() = default; | SharedMemInfo() = default; | ||||
SharedMemInfo(string name, uint64_t size) : op_name(std::move(name)), mem_size(size) {} | SharedMemInfo(string name, uint64_t size) : op_name(std::move(name)), mem_size(size) {} | ||||
}; | }; | ||||
@@ -127,6 +127,10 @@ Status GraphOptimize::OptimizeSubGraph(ComputeGraphPtr &compute_graph, const std | |||||
} | } | ||||
Status GraphOptimize::OptimizeOriginalGraph(ComputeGraphPtr &compute_graph) { | Status GraphOptimize::OptimizeOriginalGraph(ComputeGraphPtr &compute_graph) { | ||||
if (GetContext().GetHostExecFlag()) { | |||||
// graph exec on host, no need OptimizeOriginalGraph | |||||
return SUCCESS; | |||||
} | |||||
if (compute_graph == nullptr) { | if (compute_graph == nullptr) { | ||||
GELOGE(GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL, "[OptimizeOriginalGraph]: compute_graph is nullptr."); | GELOGE(GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL, "[OptimizeOriginalGraph]: compute_graph is nullptr."); | ||||
return GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL; | return GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL; | ||||
@@ -162,7 +166,7 @@ Status GraphOptimize::OptimizeOriginalGraph(ComputeGraphPtr &compute_graph) { | |||||
Status GraphOptimize::OptimizeOriginalGraphJudgeInsert(ComputeGraphPtr &compute_graph) { | Status GraphOptimize::OptimizeOriginalGraphJudgeInsert(ComputeGraphPtr &compute_graph) { | ||||
GELOGD("OptimizeOriginalGraphJudgeInsert in"); | GELOGD("OptimizeOriginalGraphJudgeInsert in"); | ||||
if (GetContext().GetHostExecFlag()) { | if (GetContext().GetHostExecFlag()) { | ||||
// graph exec on host, no need OptimizeOriginalGraph | |||||
// graph exec on host, no need OptimizeOriginalGraphJudgeInsert | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -29,7 +29,6 @@ static const std::set<std::string> kNoTaskNodeTypes = { ge::DATA, ge::ANN_DATA, | |||||
} | } | ||||
namespace ge { | namespace ge { | ||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
Status AssignRemovePass::Run(NodePtr &node) { | Status AssignRemovePass::Run(NodePtr &node) { | ||||
GELOGD("AssignRemovePass running"); | GELOGD("AssignRemovePass running"); | ||||
@@ -145,71 +144,7 @@ Status AssignRemovePass::TransformAttr(NodePtr &node) { | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
#else | |||||
Status AssignRemovePass::Run(NodePtr &node) { | |||||
GELOGD("AssignRemovePass running"); | |||||
if (node->GetType() != ASSIGN) { | |||||
GELOGD("No need run AssignRemovePass on [%s, %s].", node->GetName().c_str(), node->GetType().c_str()); | |||||
return SUCCESS; | |||||
} | |||||
const auto &ref_in_anchor = node->GetInDataAnchor(kAssignRefInputIndex); | |||||
const auto &value_in_anchor = node->GetInDataAnchor(kAssignValueInputIndex); | |||||
if ((ref_in_anchor == nullptr) || (value_in_anchor == nullptr)) { | |||||
GELOGE(FAILED, "In data anchor is null, node:%s", node->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
const auto &ref_peer_anchor = ref_in_anchor->GetPeerOutAnchor(); | |||||
const auto &value_peer_anchor = value_in_anchor->GetPeerOutAnchor(); | |||||
if ((ref_peer_anchor == nullptr) || (value_peer_anchor == nullptr)) { | |||||
GELOGE(FAILED, "Peer data anchor is null, node:%s", node->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
if (IsCondMatch(node, ref_peer_anchor, value_peer_anchor)) { | |||||
/// | |||||
/// variable not-const not-const | |||||
/// \ / | | |||||
/// \ / | | |||||
/// Assign ----> variable | |||||
/// | | | |||||
/// | | | |||||
/// node node | |||||
/// | |||||
GELOGI("Optimization for assign_node %s start", node->GetName().c_str()); | |||||
if (IsolateAndDeleteNode(node, {kAssignRefInputIndex}) != SUCCESS) { | |||||
GELOGE(FAILED, "Isolate and delete assign_node %s failed.", node->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
AddNodeDeleted(node); | |||||
const auto &ref_input = ref_peer_anchor->GetOwnerNode()->GetOpDesc(); | |||||
const auto &value_input = value_peer_anchor->GetOwnerNode()->GetOpDesc(); | |||||
if ((ref_input == nullptr) || (value_input == nullptr)) { | |||||
GELOGE(FAILED, "value input is null"); | |||||
return FAILED; | |||||
} | |||||
if (!AttrUtils::SetStr(value_input->MutableOutputDesc(value_peer_anchor->GetIdx()), ASSIGN_VAR_NAME, | |||||
ref_input->GetName())) { | |||||
GELOGE(FAILED, "Set attr ASSIGN_VAR_NAME failed."); | |||||
return FAILED; | |||||
} | |||||
// variable has and only has one input | |||||
if (ref_input->UpdateInputDesc(0, value_input->GetOutputDesc(value_peer_anchor->GetIdx())) != GRAPH_SUCCESS) { | |||||
GELOGE(FAILED, "Update input_desc for variable %s failed.", ref_input->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
if (GraphUtils::AddEdge(value_peer_anchor, ref_peer_anchor->GetOwnerNode()->GetInDataAnchor(0)) != GRAPH_SUCCESS) { | |||||
GELOGE(FAILED, "Add data edge %s->%s failed", value_input->GetName().c_str(), ref_input->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
} | |||||
GELOGD("AssignRemovePass success"); | |||||
return SUCCESS; | |||||
} | |||||
#endif | |||||
/// | /// | ||||
/// @brief Check if need optimize for assign_node | /// @brief Check if need optimize for assign_node | ||||
/// @param [in] assign_node | /// @param [in] assign_node | ||||
@@ -218,7 +153,7 @@ Status AssignRemovePass::Run(NodePtr &node) { | |||||
/// @return Status | /// @return Status | ||||
/// | /// | ||||
bool AssignRemovePass::IsCondMatch(const NodePtr &node, const OutDataAnchorPtr &ref_peer_anchor, | bool AssignRemovePass::IsCondMatch(const NodePtr &node, const OutDataAnchorPtr &ref_peer_anchor, | ||||
const OutDataAnchorPtr &value_peer_anchor) { | |||||
const OutDataAnchorPtr &value_peer_anchor) { | |||||
GELOGD("Check if assign_node %s match optimization condition, ref_input: %s, value_input: %s", | GELOGD("Check if assign_node %s match optimization condition, ref_input: %s, value_input: %s", | ||||
node->GetName().c_str(), ref_peer_anchor->GetOwnerNode()->GetName().c_str(), | node->GetName().c_str(), ref_peer_anchor->GetOwnerNode()->GetName().c_str(), | ||||
value_peer_anchor->GetOwnerNode()->GetName().c_str()); | value_peer_anchor->GetOwnerNode()->GetName().c_str()); | ||||
@@ -25,7 +25,6 @@ class AssignRemovePass : public BaseNodePass { | |||||
Status Run(NodePtr &node) override; | Status Run(NodePtr &node) override; | ||||
private: | private: | ||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
/// | /// | ||||
/// @brief Optimize for assign_node | /// @brief Optimize for assign_node | ||||
/// @param [in] assign_node | /// @param [in] assign_node | ||||
@@ -39,7 +38,7 @@ class AssignRemovePass : public BaseNodePass { | |||||
/// @return Status | /// @return Status | ||||
/// | /// | ||||
Status TransformAttr(NodePtr &node); | Status TransformAttr(NodePtr &node); | ||||
#endif | |||||
/// | /// | ||||
/// @brief Check if need optimize for assign_node | /// @brief Check if need optimize for assign_node | ||||
/// @param [in] assign_node | /// @param [in] assign_node | ||||
@@ -115,21 +115,15 @@ void ConstantFuseSamePass::GetFuseConstNodes(ComputeGraphPtr &graph, | |||||
TypeUtils::DataTypeToSerialString(data_type).c_str()); | TypeUtils::DataTypeToSerialString(data_type).c_str()); | ||||
continue; | continue; | ||||
} | } | ||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
if ((type_size != 0) && (weight->MutableData().GetAlignedPtr() == nullptr)) { | if ((type_size != 0) && (weight->MutableData().GetAlignedPtr() == nullptr)) { | ||||
GELOGW("aligned_ptr is null while size is not 0"); | GELOGW("aligned_ptr is null while size is not 0"); | ||||
continue; | continue; | ||||
} | } | ||||
#endif | |||||
++insert_const_nums; | ++insert_const_nums; | ||||
SameConstKey map_key; | SameConstKey map_key; | ||||
map_key.data_size = type_size; | map_key.data_size = type_size; | ||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
map_key.aligned_ptr = weight->MutableData().GetAlignedPtr(); | map_key.aligned_ptr = weight->MutableData().GetAlignedPtr(); | ||||
#else | |||||
map_key.data = weight->GetData().GetData(); | |||||
#endif | |||||
map_key.data_type = data_type; | map_key.data_type = data_type; | ||||
map_key.format = output_tensor->GetFormat(); | map_key.format = output_tensor->GetFormat(); | ||||
map_key.shape = output_tensor->GetShape().GetDims(); | map_key.shape = output_tensor->GetShape().GetDims(); | ||||
@@ -21,20 +21,14 @@ | |||||
#include <set> | #include <set> | ||||
#include <utility> | #include <utility> | ||||
#include <vector> | #include <vector> | ||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
#include "graph/aligned_ptr.h" | #include "graph/aligned_ptr.h" | ||||
#endif | |||||
#include "graph/types.h" | #include "graph/types.h" | ||||
#include "inc/graph_pass.h" | #include "inc/graph_pass.h" | ||||
namespace ge { | namespace ge { | ||||
struct SameConstKey { | struct SameConstKey { | ||||
int data_size; | int data_size; | ||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
std::shared_ptr<AlignedPtr> aligned_ptr; | std::shared_ptr<AlignedPtr> aligned_ptr; | ||||
#else | |||||
const uint8_t *data; | |||||
#endif | |||||
DataType data_type; | DataType data_type; | ||||
Format format; | Format format; | ||||
std::vector<int64_t> shape; | std::vector<int64_t> shape; | ||||
@@ -44,19 +38,12 @@ struct SameConstKey { | |||||
if (data_size != key.data_size) { | if (data_size != key.data_size) { | ||||
return data_size < key.data_size; | return data_size < key.data_size; | ||||
} | } | ||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
if (data_size != 0) { | if (data_size != 0) { | ||||
int ret = memcmp(aligned_ptr->Get(), key.aligned_ptr->Get(), data_size); | int ret = memcmp(aligned_ptr->Get(), key.aligned_ptr->Get(), data_size); | ||||
if (ret != 0) { | if (ret != 0) { | ||||
return ret < 0; | return ret < 0; | ||||
} | } | ||||
} | } | ||||
#else | |||||
int ret = memcmp(data, key.data, data_size); | |||||
if (ret != 0) { | |||||
return ret < 0; | |||||
} | |||||
#endif | |||||
if (data_type != key.data_type) { | if (data_type != key.data_type) { | ||||
return data_type < key.data_type; | return data_type < key.data_type; | ||||
} | } | ||||
@@ -38,9 +38,6 @@ | |||||
#include "graph/passes/aicpu_constant_folding_pass.h" | #include "graph/passes/aicpu_constant_folding_pass.h" | ||||
#include "graph/passes/assert_pass.h" | #include "graph/passes/assert_pass.h" | ||||
#include "ge/ge_api_types.h" | #include "ge/ge_api_types.h" | ||||
#ifdef ONLY_COMPILE_OPEN_SRC | |||||
#include "graph/passes/assign_remove_pass.h" | |||||
#endif | |||||
#include "graph/passes/common_subexpression_elimination_pass.h" | #include "graph/passes/common_subexpression_elimination_pass.h" | ||||
#include "graph/passes/cond_pass.h" | #include "graph/passes/cond_pass.h" | ||||
#include "graph/passes/cond_remove_pass.h" | #include "graph/passes/cond_remove_pass.h" | ||||
@@ -1865,9 +1862,6 @@ Status GraphPrepare::PrepareOptimize() { | |||||
VarIsInitializedOpPass var_is_initialized_pass; | VarIsInitializedOpPass var_is_initialized_pass; | ||||
ParallelConcatStartOpPass parallel_concat_start_op_pass; | ParallelConcatStartOpPass parallel_concat_start_op_pass; | ||||
IdentityPass identity_pass(false); | IdentityPass identity_pass(false); | ||||
#ifdef ONLY_COMPILE_OPEN_SRC | |||||
AssignRemovePass assign_remove_pass; | |||||
#endif | |||||
SnapshotPass snapshot_pass; | SnapshotPass snapshot_pass; | ||||
if (!options_.train_graph_flag) { | if (!options_.train_graph_flag) { | ||||
names_to_passes.emplace_back("DropOutPass", &dropout_pass); | names_to_passes.emplace_back("DropOutPass", &dropout_pass); | ||||
@@ -1882,11 +1876,6 @@ Status GraphPrepare::PrepareOptimize() { | |||||
names_to_passes.emplace_back("VarIsInitializedOpPass", &var_is_initialized_pass); | names_to_passes.emplace_back("VarIsInitializedOpPass", &var_is_initialized_pass); | ||||
names_to_passes.emplace_back("ParallelConcatStartOpPass", ¶llel_concat_start_op_pass); | names_to_passes.emplace_back("ParallelConcatStartOpPass", ¶llel_concat_start_op_pass); | ||||
names_to_passes.emplace_back("IdentityPass", &identity_pass); | names_to_passes.emplace_back("IdentityPass", &identity_pass); | ||||
#ifdef ONLY_COMPILE_OPEN_SRC | |||||
if (GetContext().GetHostExecFlag()) { | |||||
names_to_passes.emplace_back("AssignRemovePass", &assign_remove_pass); | |||||
} | |||||
#endif | |||||
GE_TIMESTAMP_START(names_to_passes); | GE_TIMESTAMP_START(names_to_passes); | ||||
ret = ge_passes.Run(names_to_passes); | ret = ge_passes.Run(names_to_passes); | ||||
GE_TIMESTAMP_END(names_to_passes, "GraphPrepare::NamesToPasses"); | GE_TIMESTAMP_END(names_to_passes, "GraphPrepare::NamesToPasses"); | ||||
@@ -20,9 +20,7 @@ | |||||
#include "graph/manager/graph_caching_allocator.h" | #include "graph/manager/graph_caching_allocator.h" | ||||
#include "graph/manager/graph_mem_allocator.h" | #include "graph/manager/graph_mem_allocator.h" | ||||
#include "graph/manager/rdma_pool_allocator.h" | #include "graph/manager/rdma_pool_allocator.h" | ||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
#include "graph/manager/host_mem_allocator.h" | #include "graph/manager/host_mem_allocator.h" | ||||
#endif | |||||
namespace ge { | namespace ge { | ||||
namespace hybrid { | namespace hybrid { | ||||
@@ -67,11 +65,7 @@ void *NpuMemoryAllocator::Allocate(std::size_t size, AllocationAttr *attr) { | |||||
if (mem_type == RDMA_HBM) { | if (mem_type == RDMA_HBM) { | ||||
buffer = MemManager::Instance().RdmaPoolInstance(RT_MEMORY_HBM).Malloc(allocate_size, device_id_); | buffer = MemManager::Instance().RdmaPoolInstance(RT_MEMORY_HBM).Malloc(allocate_size, device_id_); | ||||
} else if (mem_type == HOST_DDR) { | } else if (mem_type == HOST_DDR) { | ||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
buffer = MemManager::Instance().HostMemInstance(RT_MEMORY_HBM).Malloc(allocate_size); | buffer = MemManager::Instance().HostMemInstance(RT_MEMORY_HBM).Malloc(allocate_size); | ||||
#else | |||||
buffer = malloc(allocate_size); | |||||
#endif | |||||
} else { | } else { | ||||
if (allocate_size > kMaxHbmMemorySize) { | if (allocate_size > kMaxHbmMemorySize) { | ||||
GELOGE(PARAM_INVALID, "Invalid HBM memory size: %zu", allocate_size); | GELOGE(PARAM_INVALID, "Invalid HBM memory size: %zu", allocate_size); | ||||
@@ -108,11 +102,7 @@ void NpuMemoryAllocator::Deallocate(void *data, MemStorageType mem_type) { | |||||
if (mem_type == RDMA_HBM) { | if (mem_type == RDMA_HBM) { | ||||
MemManager::Instance().RdmaPoolInstance(RT_MEMORY_HBM).Free(reinterpret_cast<uint8_t *>(data), device_id_); | MemManager::Instance().RdmaPoolInstance(RT_MEMORY_HBM).Free(reinterpret_cast<uint8_t *>(data), device_id_); | ||||
} else if (mem_type == HOST_DDR) { | } else if (mem_type == HOST_DDR) { | ||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
MemManager::Instance().HostMemInstance(RT_MEMORY_HBM).Free(data); | MemManager::Instance().HostMemInstance(RT_MEMORY_HBM).Free(data); | ||||
#else | |||||
free(data); | |||||
#endif | |||||
} else { | } else { | ||||
MemManager::Instance().CachingInstance(RT_MEMORY_HBM).Free(reinterpret_cast<uint8_t *>(data), device_id_); | MemManager::Instance().CachingInstance(RT_MEMORY_HBM).Free(reinterpret_cast<uint8_t *>(data), device_id_); | ||||
} | } | ||||
@@ -25,10 +25,8 @@ | |||||
#include "graph/manager/graph_var_manager.h" | #include "graph/manager/graph_var_manager.h" | ||||
#include "graph/manager/host_mem_manager.h" | #include "graph/manager/host_mem_manager.h" | ||||
#include "graph/manager/trans_var_data_utils.h" | #include "graph/manager/trans_var_data_utils.h" | ||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
#include "graph/manager/graph_mem_allocator.h" | #include "graph/manager/graph_mem_allocator.h" | ||||
#include "graph/manager/host_mem_allocator.h" | #include "graph/manager/host_mem_allocator.h" | ||||
#endif | |||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
#include "hybrid/common/npu_memory_allocator.h" | #include "hybrid/common/npu_memory_allocator.h" | ||||
#include "hybrid/node_executor/node_executor.h" | #include "hybrid/node_executor/node_executor.h" | ||||
@@ -865,7 +863,6 @@ Status HybridModelBuilder::InitConstantOps() { | |||||
std::unique_ptr<TensorValue> var_tensor; | std::unique_ptr<TensorValue> var_tensor; | ||||
if (GetContext().GetHostExecFlag()) { | if (GetContext().GetHostExecFlag()) { | ||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
GE_CHECK_NOTNULL(ge_tensor); | GE_CHECK_NOTNULL(ge_tensor); | ||||
// Address for eigen kernel should be aligned with 16 bytes | // Address for eigen kernel should be aligned with 16 bytes | ||||
// Tensors return by api GetWeights share data with proto, whose addr is not confirmed to be aligned | // Tensors return by api GetWeights share data with proto, whose addr is not confirmed to be aligned | ||||
@@ -878,11 +875,6 @@ Status HybridModelBuilder::InitConstantOps() { | |||||
} | } | ||||
var_tensor.reset(new(std::nothrow)TensorValue(aligned_tensor.MutableData().data(), | var_tensor.reset(new(std::nothrow)TensorValue(aligned_tensor.MutableData().data(), | ||||
aligned_tensor.GetData().size())); | aligned_tensor.GetData().size())); | ||||
#else | |||||
auto buffer = ge_tensor->MutableData(); | |||||
GELOGD("Init tensor with host constant. size = %zu", buffer.GetSize()); | |||||
var_tensor.reset(new(std::nothrow)TensorValue(buffer.GetData(), buffer.GetSize())); | |||||
#endif | |||||
} else { | } else { | ||||
GE_CHK_STATUS_RET_NOLOG(VarNodeToTensor(var_node, var_tensor)); | GE_CHK_STATUS_RET_NOLOG(VarNodeToTensor(var_node, var_tensor)); | ||||
GELOGD("Init const op tensor. name = %s, size = %ld", var_name.c_str(), var_tensor->GetSize()); | GELOGD("Init const op tensor. name = %s, size = %ld", var_name.c_str(), var_tensor->GetSize()); | ||||
@@ -937,7 +929,6 @@ Status HybridModelBuilder::InitVariableTensors() { | |||||
GELOGE(GE_GRAPH_MALLOC_FAILED, "Host variable [%s] malloc failed.", it.first.c_str()); | GELOGE(GE_GRAPH_MALLOC_FAILED, "Host variable [%s] malloc failed.", it.first.c_str()); | ||||
return GE_GRAPH_MALLOC_FAILED; | return GE_GRAPH_MALLOC_FAILED; | ||||
} | } | ||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
if (MemManager::Instance().HostMemInstance(RT_MEMORY_HBM).Malloc(mem_info.host_aligned_ptr, | if (MemManager::Instance().HostMemInstance(RT_MEMORY_HBM).Malloc(mem_info.host_aligned_ptr, | ||||
tensor_size) == nullptr) { | tensor_size) == nullptr) { | ||||
GELOGE(MEMALLOC_FAILED, "Malloc host memory for an existed GeTensor failed."); | GELOGE(MEMALLOC_FAILED, "Malloc host memory for an existed GeTensor failed."); | ||||
@@ -947,11 +938,6 @@ Status HybridModelBuilder::InitVariableTensors() { | |||||
std::unique_ptr<TensorValue> tensor(new (std::nothrow) TensorValue(mem_info.host_aligned_ptr->MutableGet(), | std::unique_ptr<TensorValue> tensor(new (std::nothrow) TensorValue(mem_info.host_aligned_ptr->MutableGet(), | ||||
tensor_size)); | tensor_size)); | ||||
#else | |||||
GELOGD("Host variable [%s] malloc success.", it.first.c_str()); | |||||
std::unique_ptr<TensorValue> tensor(new (std::nothrow) TensorValue(mem_info.host_address, tensor_size)); | |||||
#endif | |||||
GE_CHECK_NOTNULL(tensor); | GE_CHECK_NOTNULL(tensor); | ||||
hybrid_model_.variable_tensors_.emplace(it.first, std::move(tensor)); | hybrid_model_.variable_tensors_.emplace(it.first, std::move(tensor)); | ||||
} | } | ||||
@@ -18,10 +18,8 @@ | |||||
#include "hybrid/node_executor/host_cpu/kernel_factory.h" | #include "hybrid/node_executor/host_cpu/kernel_factory.h" | ||||
#include "graph/passes/folding_pass.h" | #include "graph/passes/folding_pass.h" | ||||
#include "hybrid/model/hybrid_model.h" | #include "hybrid/model/hybrid_model.h" | ||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
#include "graph/manager/graph_mem_allocator.h" | #include "graph/manager/graph_mem_allocator.h" | ||||
#include "graph/manager/host_mem_allocator.h" | #include "graph/manager/host_mem_allocator.h" | ||||
#endif | |||||
#include "ge_local_engine/engine/host_cpu_engine.h" | #include "ge_local_engine/engine/host_cpu_engine.h" | ||||
namespace ge { | namespace ge { | ||||
@@ -54,18 +52,11 @@ Status CpuKernelNodeTask::Execute(TaskContext &context) { | |||||
auto input_desc_ptr = context.GetInputDesc(i); | auto input_desc_ptr = context.GetInputDesc(i); | ||||
GE_CHECK_NOTNULL(input_desc_ptr); | GE_CHECK_NOTNULL(input_desc_ptr); | ||||
const auto &input_desc = *input_desc_ptr; | const auto &input_desc = *input_desc_ptr; | ||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
auto tensor = context.GetInput(i); | auto tensor = context.GetInput(i); | ||||
GE_CHECK_NOTNULL(tensor); | GE_CHECK_NOTNULL(tensor); | ||||
auto item = MemManager::Instance().HostMemInstance(RT_MEMORY_HBM).GetAlignedPtr(tensor->GetData()); | auto item = MemManager::Instance().HostMemInstance(RT_MEMORY_HBM).GetAlignedPtr(tensor->GetData()); | ||||
GE_CHECK_NOTNULL(item.second); | GE_CHECK_NOTNULL(item.second); | ||||
auto in_tensor = MakeShared<GeTensor>(input_desc, item.second, item.first); | auto in_tensor = MakeShared<GeTensor>(input_desc, item.second, item.first); | ||||
#else | |||||
GE_CHECK_NOTNULL(context.GetInput(i)); | |||||
auto in_tensor = MakeShared<GeTensor>(input_desc, | |||||
reinterpret_cast<const uint8_t *>(context.GetInput(i)->GetData()), | |||||
context.GetInput(i)->GetSize()); | |||||
#endif | |||||
GE_CHECK_NOTNULL(in_tensor); | GE_CHECK_NOTNULL(in_tensor); | ||||
in_tensor->MutableTensorDesc().SetDataType(input_desc.GetDataType()); | in_tensor->MutableTensorDesc().SetDataType(input_desc.GetDataType()); | ||||
in_tensor->MutableTensorDesc().SetShape(input_desc.GetShape()); | in_tensor->MutableTensorDesc().SetShape(input_desc.GetShape()); | ||||
@@ -84,15 +75,9 @@ Status CpuKernelNodeTask::Execute(TaskContext &context) { | |||||
} | } | ||||
auto tensor = context.GetOutput(i); | auto tensor = context.GetOutput(i); | ||||
GE_CHECK_NOTNULL(tensor); | GE_CHECK_NOTNULL(tensor); | ||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
auto item = MemManager::Instance().HostMemInstance(RT_MEMORY_HBM).GetAlignedPtr(tensor->GetData()); | auto item = MemManager::Instance().HostMemInstance(RT_MEMORY_HBM).GetAlignedPtr(tensor->GetData()); | ||||
GE_CHECK_NOTNULL(item.second); | GE_CHECK_NOTNULL(item.second); | ||||
auto out_tensor = MakeShared<GeTensor>(output_desc, item.second, item.first); | auto out_tensor = MakeShared<GeTensor>(output_desc, item.second, item.first); | ||||
#else | |||||
auto out_tensor = MakeShared<GeTensor>(output_desc, | |||||
reinterpret_cast<const uint8_t *>(tensor->GetData()), | |||||
tensor->GetSize()); | |||||
#endif | |||||
GE_CHECK_NOTNULL(out_tensor); | GE_CHECK_NOTNULL(out_tensor); | ||||
out_tensor->MutableTensorDesc().SetDataType(output_desc.GetDataType()); | out_tensor->MutableTensorDesc().SetDataType(output_desc.GetDataType()); | ||||
out_tensor->MutableTensorDesc().SetShape(output_desc.GetShape()); | out_tensor->MutableTensorDesc().SetShape(output_desc.GetShape()); | ||||