Browse Source

Feature:atc single op support optional input!

tags/v1.1.0
wxl 3 years ago
parent
commit
d9c98b8715
5 changed files with 44 additions and 23 deletions
  1. +29
    -3
      ge/generator/ge_generator.cc
  2. +5
    -10
      ge/offline/single_op_parser.cc
  3. +8
    -8
      inc/framework/generator/ge_generator.h
  4. +1
    -1
      metadef
  5. +1
    -1
      parser

+ 29
- 3
ge/generator/ge_generator.cc View File

@@ -528,9 +528,19 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr
return SUCCESS; return SUCCESS;
} }


Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs,
const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff,
bool is_offline) {
namespace {
bool IsNeedConnectInputOpForSingleOp(GeTensorDesc &tensor_desc) {
bool is_need = true;
// format and dtype is all reserved, stand for Optional input. When singleop scene
if (tensor_desc.GetFormat() == FORMAT_RESERVED && tensor_desc.GetDataType() == DT_UNDEFINED) {
is_need = false;
}
return is_need;
}
}

Status GeGenerator::CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
const vector<GeTensor> &outputs) {
GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID); GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID);
if (!inputs.empty() && (inputs.size() != op_desc->GetAllInputsSize())) { if (!inputs.empty() && (inputs.size() != op_desc->GetAllInputsSize())) {
GELOGE(PARAM_INVALID, "Tensor size: %zu, Inputs size: %zu", inputs.size(), op_desc->GetAllInputsSize()); GELOGE(PARAM_INVALID, "Tensor size: %zu, Inputs size: %zu", inputs.size(), op_desc->GetAllInputsSize());
@@ -540,7 +550,17 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &in
GELOGE(PARAM_INVALID, "Tensor size: %zu, Outputs size: %zu", outputs.size(), op_desc->GetOutputsSize()); GELOGE(PARAM_INVALID, "Tensor size: %zu, Outputs size: %zu", outputs.size(), op_desc->GetOutputsSize());
return PARAM_INVALID; return PARAM_INVALID;
} }
return SUCCESS;
}

Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs,
const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff,
bool is_offline) {


if (CheckForSingleOp(op_desc, inputs, outputs) != SUCCESS) {
GELOGE(PARAM_INVALID, "input param is invalid when build single op!");
return PARAM_INVALID;
}
OmgContext &omg_context = (impl_ == nullptr) ? domi::GetContext() : impl_->omg_context_; OmgContext &omg_context = (impl_ == nullptr) ? domi::GetContext() : impl_->omg_context_;
omg_context.is_dynamic_input = ContainsDynamicInpus(*op_desc); omg_context.is_dynamic_input = ContainsDynamicInpus(*op_desc);


@@ -575,12 +595,18 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &in
if (inputs.empty()) { if (inputs.empty()) {
for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
GE_CHECK_NOTNULL_EXEC(input_desc, return INTERNAL_ERROR); GE_CHECK_NOTNULL_EXEC(input_desc, return INTERNAL_ERROR);
if (!IsNeedConnectInputOpForSingleOp(*input_desc)) {
continue;
}
GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, *input_desc, arg_index, false)); GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, *input_desc, arg_index, false));
arg_index++; arg_index++;
} }
} else { } else {
for (const auto &in_desc : inputs) { for (const auto &in_desc : inputs) {
GeTensorDesc input_desc = in_desc.GetTensorDesc(); GeTensorDesc input_desc = in_desc.GetTensorDesc();
if (!IsNeedConnectInputOpForSingleOp(input_desc)) {
continue;
}
GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, input_desc, arg_index, true)); GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, input_desc, arg_index, true));
arg_index++; arg_index++;
} }


+ 5
- 10
ge/offline/single_op_parser.cc View File

@@ -226,16 +226,11 @@ bool SingleOpParser::Validate(const SingleOpDesc &op_desc) {
} }


int index = 0; int index = 0;
for (auto &tensor_desc : op_desc.input_desc) {
if (tensor_desc.type == DT_UNDEFINED) {
ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "index"}, {"input", std::to_string(index)});
GELOGE(false, "Input's dataType is invalid when the index is %d", index);
return false;
}

if (tensor_desc.format == FORMAT_RESERVED) {
ErrorManager::GetInstance().ATCReportErrMessage("E10028", {"input", "index"}, {"input", std::to_string(index)});
GELOGE(PARAM_INVALID, "Input's format is invalid when the index is %d", index);
for (auto &tensor_desc : op_desc.output_desc) {
if ((tensor_desc.type == DT_UNDEFINED && tensor_desc.format != FORMAT_RESERVED) ||
(tensor_desc.type != DT_UNDEFINED && tensor_desc.format == FORMAT_RESERVED)){
ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "index"}, {"output", std::to_string(index)});
GELOGE(PARAM_INVALID, "Input's dataType or format is invalid when the index is %d", index);
return false; return false;
} }
++index; ++index;


+ 8
- 8
inc/framework/generator/ge_generator.h View File

@@ -53,7 +53,7 @@ class GeGenerator {
Status GenerateOfflineModel(const Graph &graph, const std::string &file_name_prefix, Status GenerateOfflineModel(const Graph &graph, const std::string &file_name_prefix,
const std::vector<GeTensor> &inputs = std::vector<GeTensor>()); const std::vector<GeTensor> &inputs = std::vector<GeTensor>());


Status GenerateOnlineModel(const Graph &graph, const vector<GeTensor> &inputs, ge::ModelBufferData& model);
Status GenerateOnlineModel(const Graph &graph, const vector<GeTensor> &inputs, ge::ModelBufferData &model);


Status GenerateInfershapeGraph(const Graph &graph); Status GenerateInfershapeGraph(const Graph &graph);


@@ -77,16 +77,16 @@ class GeGenerator {
/// @param [in] engine_type: specific engine. /// @param [in] engine_type: specific engine.
/// @param [out] model_buff: model buff of single op. /// @param [out] model_buff: model buff of single op.
/// @return SUCCESS or FAILED /// @return SUCCESS or FAILED
Status BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
const vector<GeTensor> &outputs, OpEngineType engine_type,
ModelBufferData &model_buff);
Status BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs,
OpEngineType engine_type, ModelBufferData &model_buff);


private: private:
Status GenerateModel(const Graph &graph, const string &file_name_prefix,
const vector<GeTensor> &inputs, ge::ModelBufferData& model, bool is_offline = true);
Status GenerateModel(const Graph &graph, const string &file_name_prefix, const vector<GeTensor> &inputs,
ge::ModelBufferData &model, bool is_offline = true);
Status BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs, Status BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs,
const string &model_file_name, OpEngineType engine_type,
ModelBufferData &model_buff, bool is_offline = true);
const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff,
bool is_offline = true);
Status CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs);


class Impl; class Impl;




+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit ae80e9e2369458d468c16fd95bb970922ff3a084
Subproject commit 1cc55bcae09902b3d158993dd57bfbd1d3337066

+ 1
- 1
parser

@@ -1 +1 @@
Subproject commit c1530c6083ea18c6c6c4c14b08253830e3982344
Subproject commit db4e6070bb2cec01cead264a44ceae07e7f3048e

Loading…
Cancel
Save