Browse Source

add tensor after fold

tags/v1.2.0
chenyemeng 3 years ago
parent
commit
fbc543626c
3 changed files with 20 additions and 16 deletions
  1. +9
    -5
      ge/ge_local_engine/engine/host_cpu_engine.cc
  2. +6
    -6
      ge/graph/passes/assign_remove_pass.cc
  3. +5
    -5
      ge/graph/passes/inplace_support_check_pass.cc

+ 9
- 5
ge/ge_local_engine/engine/host_cpu_engine.cc View File

@@ -232,12 +232,16 @@ Status HostCpuEngine::Run(NodePtr &node, const vector<ConstGeTensorPtr> &inputs,
std::vector<GeTensorPtr> tmp_outputs;
for (size_t i = 0; i < op_desc->GetOutputsSize(); i++) {
auto tensor_name = op_desc->GetOutputNameByIndex(i);
GE_RETURN_WITH_LOG_IF_TRUE(tensor_name.empty(), "Failed to get output name. node = %s, index = %zu",
op_desc->GetName().c_str(), i);
if (tensor_name.empty()) {
GELOGE(INTERNAL_ERROR, "Failed to get output name. node = %s, index = %zu", op_desc->GetName().c_str(), i);
return INTERNAL_ERROR;
}
auto iter = named_outputs.find(tensor_name);
GE_RETURN_WITH_LOG_IF_TRUE(iter == named_outputs.end(),
"Failed to get output tensor. node = %s, index = %zu, tensor_name = %s",
op_desc->GetName().c_str(), i, tensor_name.c_str());
if (iter == named_outputs.end()) {
GELOGE(INTERNAL_ERROR, "Failed to get output tensor. node = %s, index = %zu, tensor_name = %s",
op_desc->GetName().c_str(), i, tensor_name.c_str());
return INTERNAL_ERROR;
}
auto ge_tensor = MakeShared<GeTensor>(TensorAdapter::AsGeTensor(iter->second));
GE_CHECK_NOTNULL(ge_tensor);
tmp_outputs.emplace_back(ge_tensor);


+ 6
- 6
ge/graph/passes/assign_remove_pass.cc View File

@@ -21,12 +21,12 @@

namespace ge {
namespace {
constexpr uint32_t kValidInputNodeOutputNum = 1;
constexpr int32_t kAssignRefInputIndex = 0;
constexpr int32_t kAssignValueInputIndex = 1;
static const std::set<std::string> kNoTaskNodeTypes = { ge::DATA, ge::ANN_DATA, ge::AIPPDATA,
ge::CONSTANT, ge::CONSTANTOP,
ge::VARIABLE, ge::VARIABLEV2 };
constexpr uint32_t kValidInputNodeOutputNum = 1;
constexpr int32_t kAssignRefInputIndex = 0;
constexpr int32_t kAssignValueInputIndex = 1;
static const std::set<std::string> kNoTaskNodeTypes = { ge::DATA, ge::ANN_DATA, ge::AIPPDATA,
ge::CONSTANT, ge::CONSTANTOP,
ge::VARIABLE, ge::VARIABLEV2 };
}

Status AssignRemovePass::Run(NodePtr &node) {


+ 5
- 5
ge/graph/passes/inplace_support_check_pass.cc View File

@@ -21,11 +21,11 @@

namespace ge {
namespace {
constexpr uint32_t kInplaceSupportOutputIndex = 0;
constexpr uint32_t kInplaceSupportOutputNum = 1;
static const std::set<std::string> kSrcNodeTypes = { ge::DATA, ge::ANN_DATA, ge::AIPPDATA,
ge::CONSTANT, ge::CONSTANTOP,
ge::VARIABLE, ge::VARIABLEV2 };
constexpr uint32_t kInplaceSupportOutputIndex = 0;
constexpr uint32_t kInplaceSupportOutputNum = 1;
static const std::set<std::string> kSrcNodeTypes = { ge::DATA, ge::ANN_DATA, ge::AIPPDATA,
ge::CONSTANT, ge::CONSTANTOP,
ge::VARIABLE, ge::VARIABLEV2 };
}
Status InplaceSupportCheckPass::Run(NodePtr &node) {
GELOGD("InplaceSupportCheckPass running");


Loading…
Cancel
Save