Browse Source

add tensor after fold

tags/v1.2.0
chenyemeng 3 years ago
parent
commit
3d359b01da
4 changed files with 27 additions and 13 deletions
  1. +12
    -2
      ge/ge_local_engine/engine/host_cpu_engine.cc
  2. +4
    -0
      ge/graph/build/model_builder.cc
  3. +6
    -6
      ge/graph/passes/assign_remove_pass.cc
  4. +5
    -5
      ge/graph/passes/inplace_support_check_pass.cc

+ 12
- 2
ge/ge_local_engine/engine/host_cpu_engine.cc View File

@@ -221,7 +221,7 @@ Status HostCpuEngine::Run(NodePtr &node, const vector<ConstGeTensorPtr> &inputs,
GELOGD("Run node by host cpu engine. node name = %s", node->GetName().c_str());
std::unique_ptr<HostCpuOp> op_kernel;
GE_CHK_STATUS_RET_NOLOG(FindOpKernel(node, op_kernel));
#ifndef ONLY_COMPILE_OPEN_SRC
std::map<std::string, const Tensor> named_inputs;
std::map<std::string, Tensor> named_outputs;
auto op_desc = node->GetOpDesc();
@@ -229,7 +229,6 @@ Status HostCpuEngine::Run(NodePtr &node, const vector<ConstGeTensorPtr> &inputs,
GE_CHK_STATUS_RET_NOLOG(PrepareOutputs(op_desc, outputs, named_outputs));
GE_CHK_STATUS_RET_NOLOG(RunInternal(op_desc, *op_kernel, named_inputs, named_outputs));

GELOGD("Run node by host cpu engine successfully. name node = %s", node->GetName().c_str());
std::vector<GeTensorPtr> tmp_outputs;
for (size_t i = 0; i < op_desc->GetOutputsSize(); i++) {
auto tensor_name = op_desc->GetOutputNameByIndex(i);
@@ -243,6 +242,17 @@ Status HostCpuEngine::Run(NodePtr &node, const vector<ConstGeTensorPtr> &inputs,
GE_CHECK_NOTNULL(ge_tensor);
tmp_outputs.emplace_back(ge_tensor);
}
#else
std::map<std::string, const Tensor> named_inputs;
std::vector<GeTensorPtr> tmp_outputs;
tmp_outputs.swap(outputs);
std::map<std::string, Tensor> named_outputs;
auto op_desc = node->GetOpDesc();
GE_CHK_STATUS_RET_NOLOG(PrepareInputs(op_desc, inputs, named_inputs));
GE_CHK_STATUS_RET_NOLOG(PrepareOutputs(op_desc, tmp_outputs, named_outputs));
GE_CHK_STATUS_RET_NOLOG(RunInternal(op_desc, *op_kernel, named_inputs, named_outputs));
#endif
GELOGD("Run node by host cpu engine successfully. name node = %s", node->GetName().c_str());
outputs.swap(tmp_outputs);
return SUCCESS;
}


+ 4
- 0
ge/graph/build/model_builder.cc View File

@@ -569,7 +569,11 @@ Status ModelBuilder::MergeWeights() {
return FAILED;
}
}
#ifndef ONLY_COMPILE_OPEN_SRC
weight->ClearData();
#else
weight_data.clear();
#endif
}

return SUCCESS;


+ 6
- 6
ge/graph/passes/assign_remove_pass.cc View File

@@ -21,12 +21,12 @@

namespace ge {
namespace {
constexpr uint32_t kValidInputNodeOutputNum = 1;
constexpr int32_t kAssignRefInputIndex = 0;
constexpr int32_t kAssignValueInputIndex = 1;
static const std::set<std::string> kNoTaskNodeTypes = { ge::DATA, ge::ANN_DATA, ge::AIPPDATA,
ge::CONSTANT, ge::CONSTANTOP,
ge::VARIABLE, ge::VARIABLEV2 };
constexpr uint32_t kValidInputNodeOutputNum = 1;
constexpr int32_t kAssignRefInputIndex = 0;
constexpr int32_t kAssignValueInputIndex = 1;
static const std::set<std::string> kNoTaskNodeTypes = { ge::DATA, ge::ANN_DATA, ge::AIPPDATA,
ge::CONSTANT, ge::CONSTANTOP,
ge::VARIABLE, ge::VARIABLEV2 };
}

Status AssignRemovePass::Run(NodePtr &node) {


+ 5
- 5
ge/graph/passes/inplace_support_check_pass.cc View File

@@ -21,11 +21,11 @@

namespace ge {
namespace {
constexpr uint32_t kInplaceSupportOutputIndex = 0;
constexpr uint32_t kInplaceSupportOutputNum = 1;
static const std::set<std::string> kSrcNodeTypes = { ge::DATA, ge::ANN_DATA, ge::AIPPDATA,
ge::CONSTANT, ge::CONSTANTOP,
ge::VARIABLE, ge::VARIABLEV2 };
constexpr uint32_t kInplaceSupportOutputIndex = 0;
constexpr uint32_t kInplaceSupportOutputNum = 1;
static const std::set<std::string> kSrcNodeTypes = { ge::DATA, ge::ANN_DATA, ge::AIPPDATA,
ge::CONSTANT, ge::CONSTANTOP,
ge::VARIABLE, ge::VARIABLEV2 };
}
Status InplaceSupportCheckPass::Run(NodePtr &node) {
GELOGD("InplaceSupportCheckPass running");


Loading…
Cancel
Save