Browse Source

Feature: support optional input for single op

pull/127/head
wxl 5 years ago
parent
commit
4522859212
2 changed files with 21 additions and 10 deletions
  1. +16
    -0
      ge/generator/ge_generator.cc
  2. +5
    -10
      ge/offline/single_op_parser.cc

+ 16
- 0
ge/generator/ge_generator.cc View File

@@ -528,6 +528,16 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr
return SUCCESS; return SUCCESS;
} }


namespace {
bool IsNeedConnectInputOpForSingleOp(GeTensorDesc &tesor_desc) {
bool is_need = true;
// format and dtype is all reserved, stand for Optional input. When singleop scene
if (tensor_desc.GetFormat() == FORMAT_RESERVED && tensor_desc.GetDataType() == DT_UNDEFINED) {
is_need = false;
}
return is_need;
}
}
Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs, Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs,
const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff, const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff,
bool is_offline) { bool is_offline) {
@@ -575,12 +585,18 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &in
if (inputs.empty()) { if (inputs.empty()) {
for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
GE_CHECK_NOTNULL_EXEC(input_desc, return INTERNAL_ERROR); GE_CHECK_NOTNULL_EXEC(input_desc, return INTERNAL_ERROR);
if (!IsNeedConnectInputOpForSingleOp(*input_desc)) {
continue;
}
GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, *input_desc, arg_index, false)); GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, *input_desc, arg_index, false));
arg_index++; arg_index++;
} }
} else { } else {
for (const auto &in_desc : inputs) { for (const auto &in_desc : inputs) {
GeTensorDesc input_desc = in_desc.GetTensorDesc(); GeTensorDesc input_desc = in_desc.GetTensorDesc();
if (!IsNeedConnectInputOpForSingleOp(input_desc)) {
continue;
}
GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, input_desc, arg_index, true)); GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, input_desc, arg_index, true));
arg_index++; arg_index++;
} }


+ 5
- 10
ge/offline/single_op_parser.cc View File

@@ -226,16 +226,11 @@ bool SingleOpParser::Validate(const SingleOpDesc &op_desc) {
} }


int index = 0; int index = 0;
for (auto &tensor_desc : op_desc.input_desc) {
if (tensor_desc.type == DT_UNDEFINED) {
ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "index"}, {"input", std::to_string(index)});
GELOGE(false, "Input's dataType is invalid when the index is %d", index);
return false;
}

if (tensor_desc.format == FORMAT_RESERVED) {
ErrorManager::GetInstance().ATCReportErrMessage("E10028", {"input", "index"}, {"input", std::to_string(index)});
GELOGE(PARAM_INVALID, "Input's format is invalid when the index is %d", index);
for (auto &tensor_desc : op_desc.output_desc) {
if ((tensor_desc.type == DT_UNDEFINED && tensor_desc.format != FORMAT_RESERVED) ||
(tensor_desc.type != DT_UNDEFINED && tensor_desc.format == FORMAT_RESERVED)){
ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "index"}, {"output", std::to_string(index)});
GELOGE(PARAM_INVALID, "Input's dataType or format is invalid when the index is %d", index);
return false; return false;
} }
++index; ++index;


Loading…
Cancel
Save