From d9c98b87153f589068b77214d6481deed2a6f5d3 Mon Sep 17 00:00:00 2001 From: wxl Date: Sat, 24 Oct 2020 17:35:38 +0800 Subject: [PATCH] Feature:atc single op support optional input! --- ge/generator/ge_generator.cc | 32 +++++++++++++++++++++++--- ge/offline/single_op_parser.cc | 15 ++++-------- inc/framework/generator/ge_generator.h | 16 ++++++------- metadef | 2 +- parser | 2 +- 5 files changed, 44 insertions(+), 23 deletions(-) diff --git a/ge/generator/ge_generator.cc b/ge/generator/ge_generator.cc index ad3084dc..f60561c7 100644 --- a/ge/generator/ge_generator.cc +++ b/ge/generator/ge_generator.cc @@ -528,9 +528,19 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr return SUCCESS; } -Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector &inputs, const vector &outputs, - const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff, - bool is_offline) { +namespace { + bool IsNeedConnectInputOpForSingleOp(GeTensorDesc &tensor_desc) { + bool is_need = true; + // format and dtype is all reserved, stand for Optional input. When singleop scene + if (tensor_desc.GetFormat() == FORMAT_RESERVED && tensor_desc.GetDataType() == DT_UNDEFINED) { + is_need = false; + } + return is_need; + } +} + +Status GeGenerator::CheckForSingleOp(OpDescPtr &op_desc, const vector &inputs, + const vector &outputs) { GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID); if (!inputs.empty() && (inputs.size() != op_desc->GetAllInputsSize())) { GELOGE(PARAM_INVALID, "Tensor size: %zu, Inputs size: %zu", inputs.size(), op_desc->GetAllInputsSize()); @@ -540,7 +550,17 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector &in GELOGE(PARAM_INVALID, "Tensor size: %zu, Outputs size: %zu", outputs.size(), op_desc->GetOutputsSize()); return PARAM_INVALID; } + return SUCCESS; +} + +Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector &inputs, const vector &outputs, + const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff, + bool is_offline) { + if (CheckForSingleOp(op_desc, inputs, outputs) != SUCCESS) { + GELOGE(PARAM_INVALID, "input param is invalid when build single op!"); + return PARAM_INVALID; + } OmgContext &omg_context = (impl_ == nullptr) ? domi::GetContext() : impl_->omg_context_; omg_context.is_dynamic_input = ContainsDynamicInpus(*op_desc); @@ -575,12 +595,18 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector &in if (inputs.empty()) { for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { GE_CHECK_NOTNULL_EXEC(input_desc, return INTERNAL_ERROR); + if (!IsNeedConnectInputOpForSingleOp(*input_desc)) { + continue; + } GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, *input_desc, arg_index, false)); arg_index++; } } else { for (const auto &in_desc : inputs) { GeTensorDesc input_desc = in_desc.GetTensorDesc(); + if (!IsNeedConnectInputOpForSingleOp(input_desc)) { + continue; + } GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, input_desc, arg_index, true)); arg_index++; } diff --git a/ge/offline/single_op_parser.cc b/ge/offline/single_op_parser.cc index df75e21d..8a86f5c5 100644 --- a/ge/offline/single_op_parser.cc +++ b/ge/offline/single_op_parser.cc @@ -226,16 +226,11 @@ bool SingleOpParser::Validate(const SingleOpDesc &op_desc) { } int index = 0; - for (auto &tensor_desc : op_desc.input_desc) { - if (tensor_desc.type == DT_UNDEFINED) { - ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "index"}, {"input", std::to_string(index)}); - GELOGE(false, "Input's dataType is invalid when the index is %d", index); - return false; - } - - if (tensor_desc.format == FORMAT_RESERVED) { - ErrorManager::GetInstance().ATCReportErrMessage("E10028", {"input", "index"}, {"input", std::to_string(index)}); - GELOGE(PARAM_INVALID, "Input's format is invalid when the index is %d", index); + for (auto &tensor_desc : op_desc.output_desc) { + if ((tensor_desc.type == DT_UNDEFINED && tensor_desc.format != FORMAT_RESERVED) || + (tensor_desc.type != DT_UNDEFINED && tensor_desc.format == FORMAT_RESERVED)){ + ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "index"}, {"output", std::to_string(index)}); + GELOGE(PARAM_INVALID, "Input's dataType or format is invalid when the index is %d", index); return false; } ++index; diff --git a/inc/framework/generator/ge_generator.h b/inc/framework/generator/ge_generator.h index 4902a021..c446b983 100644 --- a/inc/framework/generator/ge_generator.h +++ b/inc/framework/generator/ge_generator.h @@ -53,7 +53,7 @@ class GeGenerator { Status GenerateOfflineModel(const Graph &graph, const std::string &file_name_prefix, const std::vector &inputs = std::vector()); - Status GenerateOnlineModel(const Graph &graph, const vector &inputs, ge::ModelBufferData& model); + Status GenerateOnlineModel(const Graph &graph, const vector &inputs, ge::ModelBufferData &model); Status GenerateInfershapeGraph(const Graph &graph); @@ -77,16 +77,16 @@ class GeGenerator { /// @param [in] engine_type: specific engine. /// @param [out] model_buff: model buff of single op. /// @return SUCCESS or FAILED - Status BuildSingleOpModel(OpDescPtr &op_desc, const vector &inputs, - const vector &outputs, OpEngineType engine_type, - ModelBufferData &model_buff); + Status BuildSingleOpModel(OpDescPtr &op_desc, const vector &inputs, const vector &outputs, + OpEngineType engine_type, ModelBufferData &model_buff); private: - Status GenerateModel(const Graph &graph, const string &file_name_prefix, - const vector &inputs, ge::ModelBufferData& model, bool is_offline = true); + Status GenerateModel(const Graph &graph, const string &file_name_prefix, const vector &inputs, + ge::ModelBufferData &model, bool is_offline = true); Status BuildSingleOp(OpDescPtr &op_desc, const vector &inputs, const vector &outputs, - const string &model_file_name, OpEngineType engine_type, - ModelBufferData &model_buff, bool is_offline = true); + const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff, + bool is_offline = true); + Status CheckForSingleOp(OpDescPtr &op_desc, const vector &inputs, const vector &outputs); class Impl; diff --git a/metadef b/metadef index ae80e9e2..1cc55bca 160000 --- a/metadef +++ b/metadef @@ -1 +1 @@ -Subproject commit ae80e9e2369458d468c16fd95bb970922ff3a084 +Subproject commit 1cc55bcae09902b3d158993dd57bfbd1d3337066 diff --git a/parser b/parser index c1530c60..db4e6070 160000 --- a/parser +++ b/parser @@ -1 +1 @@ -Subproject commit c1530c6083ea18c6c6c4c14b08253830e3982344 +Subproject commit db4e6070bb2cec01cead264a44ceae07e7f3048e