From fbc543626c0a17c78ed5d8074daa73ba50ad71e5 Mon Sep 17 00:00:00 2001 From: chenyemeng Date: Tue, 5 Jan 2021 15:18:03 +0800 Subject: [PATCH] add tensor after fold --- ge/ge_local_engine/engine/host_cpu_engine.cc | 14 +++++++++----- ge/graph/passes/assign_remove_pass.cc | 12 ++++++------ ge/graph/passes/inplace_support_check_pass.cc | 10 +++++----- 3 files changed, 20 insertions(+), 16 deletions(-) diff --git a/ge/ge_local_engine/engine/host_cpu_engine.cc b/ge/ge_local_engine/engine/host_cpu_engine.cc index 1197f466..99ee8794 100755 --- a/ge/ge_local_engine/engine/host_cpu_engine.cc +++ b/ge/ge_local_engine/engine/host_cpu_engine.cc @@ -232,12 +232,16 @@ Status HostCpuEngine::Run(NodePtr &node, const vector &inputs, std::vector tmp_outputs; for (size_t i = 0; i < op_desc->GetOutputsSize(); i++) { auto tensor_name = op_desc->GetOutputNameByIndex(i); - GE_RETURN_WITH_LOG_IF_TRUE(tensor_name.empty(), "Failed to get output name. node = %s, index = %zu", - op_desc->GetName().c_str(), i); + if (tensor_name.empty()) { + GELOGE(INTERNAL_ERROR, "Failed to get output name. node = %s, index = %zu", op_desc->GetName().c_str(), i); + return INTERNAL_ERROR; + } auto iter = named_outputs.find(tensor_name); - GE_RETURN_WITH_LOG_IF_TRUE(iter == named_outputs.end(), - "Failed to get output tensor. node = %s, index = %zu, tensor_name = %s", - op_desc->GetName().c_str(), i, tensor_name.c_str()); + if (iter == named_outputs.end()) { + GELOGE(INTERNAL_ERROR, "Failed to get output tensor. node = %s, index = %zu, tensor_name = %s", + op_desc->GetName().c_str(), i, tensor_name.c_str()); + return INTERNAL_ERROR; + } auto ge_tensor = MakeShared(TensorAdapter::AsGeTensor(iter->second)); GE_CHECK_NOTNULL(ge_tensor); tmp_outputs.emplace_back(ge_tensor); diff --git a/ge/graph/passes/assign_remove_pass.cc b/ge/graph/passes/assign_remove_pass.cc index fc971cc0..e198c2db 100644 --- a/ge/graph/passes/assign_remove_pass.cc +++ b/ge/graph/passes/assign_remove_pass.cc @@ -21,12 +21,12 @@ namespace ge { namespace { - constexpr uint32_t kValidInputNodeOutputNum = 1; - constexpr int32_t kAssignRefInputIndex = 0; - constexpr int32_t kAssignValueInputIndex = 1; - static const std::set kNoTaskNodeTypes = { ge::DATA, ge::ANN_DATA, ge::AIPPDATA, - ge::CONSTANT, ge::CONSTANTOP, - ge::VARIABLE, ge::VARIABLEV2 }; +constexpr uint32_t kValidInputNodeOutputNum = 1; +constexpr int32_t kAssignRefInputIndex = 0; +constexpr int32_t kAssignValueInputIndex = 1; +static const std::set kNoTaskNodeTypes = { ge::DATA, ge::ANN_DATA, ge::AIPPDATA, + ge::CONSTANT, ge::CONSTANTOP, + ge::VARIABLE, ge::VARIABLEV2 }; } Status AssignRemovePass::Run(NodePtr &node) { diff --git a/ge/graph/passes/inplace_support_check_pass.cc b/ge/graph/passes/inplace_support_check_pass.cc index 9f683751..44ad8361 100644 --- a/ge/graph/passes/inplace_support_check_pass.cc +++ b/ge/graph/passes/inplace_support_check_pass.cc @@ -21,11 +21,11 @@ namespace ge { namespace { - constexpr uint32_t kInplaceSupportOutputIndex = 0; - constexpr uint32_t kInplaceSupportOutputNum = 1; - static const std::set kSrcNodeTypes = { ge::DATA, ge::ANN_DATA, ge::AIPPDATA, - ge::CONSTANT, ge::CONSTANTOP, - ge::VARIABLE, ge::VARIABLEV2 }; +constexpr uint32_t kInplaceSupportOutputIndex = 0; +constexpr uint32_t kInplaceSupportOutputNum = 1; +static const std::set kSrcNodeTypes = { ge::DATA, ge::ANN_DATA, ge::AIPPDATA, + ge::CONSTANT, ge::CONSTANTOP, + ge::VARIABLE, ge::VARIABLEV2 }; } Status InplaceSupportCheckPass::Run(NodePtr &node) { GELOGD("InplaceSupportCheckPass running");