|
|
@@ -92,8 +92,7 @@ Status MultiBatchClonePass::Run(ComputeGraphPtr graph) { |
|
|
|
} |
|
|
|
|
|
|
|
// parser data dynamic info from atc parameter --input_shape |
|
|
|
if (multibatch::ParserDataToDynmaicInfo(batch_shapes_, GetLocalOmgContext().user_input_dims, |
|
|
|
data_to_dynamic_info_) != SUCCESS) { |
|
|
|
if (CheckAndParseDynamicData() != SUCCESS) { |
|
|
|
GELOGE(PARAM_INVALID, "Parse each data's own dynamic info failed"); |
|
|
|
return PARAM_INVALID; |
|
|
|
} |
|
|
@@ -177,6 +176,59 @@ Status MultiBatchClonePass::CollectIoNodes(const ComputeGraphPtr &graph) { |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status MultiBatchClonePass::CheckAndParseDynamicData(){ |
|
|
|
size_t unknown_shape_count = 0; |
|
|
|
auto data_name_and_shape = GetLocalOmgContext().user_input_dims; |
|
|
|
std::vector<std::string> data_name_order; |
|
|
|
for (auto &item : GetLocalOmgContext().user_input_dims) { |
|
|
|
data_name_order.push_back(item.first); |
|
|
|
} |
|
|
|
if (!getnext_sink_dynamic_dims_) { |
|
|
|
for (const auto &node : all_data_nodes_) { |
|
|
|
auto data_desc = NodeUtils::GetOutputDesc(*node, kDataOutIndex); |
|
|
|
auto data_shape = data_desc.GetShape(); |
|
|
|
auto data_format = data_desc.GetFormat() == Format::FORMAT_NCHW ? "NCHW" : |
|
|
|
data_desc.GetFormat() == Format::FORMAT_NHWC ? "NHWC" : "Others"; |
|
|
|
auto data_name = node->GetName(); |
|
|
|
GELOGI("CheckAndParseDynamicData shape_dims is %s.", formats::JoinToString(data_shape.GetDims()).c_str()); |
|
|
|
|
|
|
|
std::vector<int64_t> data_shape_dims = data_shape.GetDims(); |
|
|
|
if (std::all_of(data_shape_dims.begin(), data_shape_dims.end(), [](int64_t val) { return val >= 0; })) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
++unknown_shape_count; |
|
|
|
auto iter = find(data_name_order.begin(), data_name_order.end(), data_name); |
|
|
|
if (iter == data_name_order.end()) { |
|
|
|
if (!GetLocalOmgContext().dynamic_batch_size.empty()) { |
|
|
|
auto ret = multibatch::CheckDynamicBatchShape(data_shape_dims, data_name); |
|
|
|
GE_IF_BOOL_EXEC(ret == false, GELOGE(PARAM_INVALID, "Failed to check dynamic batch shape of %s.", |
|
|
|
data_name.c_str()); return PARAM_INVALID); |
|
|
|
} else if (!GetLocalOmgContext().dynamic_image_size.empty()) { |
|
|
|
auto ret = multibatch::CheckDynamicImageSizeShape(data_shape_dims, data_name, data_format); |
|
|
|
GE_IF_BOOL_EXEC(ret == false, GELOGE(PARAM_INVALID, "Failed to check dynamic image size shape of %s.", |
|
|
|
data_name.c_str()); return PARAM_INVALID); |
|
|
|
} else if (!GetLocalOmgContext().dynamic_dims.empty()) { |
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E10001",{"parameter", "reason"}, |
|
|
|
{"--input_shape","all dynamic data must be set in --input_shape"}); |
|
|
|
GELOGE(INTERNAL_ERROR, "data: %s shape:%s must be set int --input_shape", |
|
|
|
node->GetName().c_str(), data_shape.ToString().c_str()); |
|
|
|
return INTERNAL_ERROR; |
|
|
|
} |
|
|
|
data_name_and_shape.emplace_back(data_name, data_shape_dims); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
auto ret = multibatch::ParserDataToDynmaicInfo(batch_shapes_, data_name_and_shape, data_to_dynamic_info_); |
|
|
|
GE_CHK_STATUS_RET(ret, "Failed to parse data to dynamic info."); |
|
|
|
if (!getnext_sink_dynamic_dims_ && unknown_shape_count == 0) { |
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E10040"); |
|
|
|
GELOGE(PARAM_INVALID, |
|
|
|
"Need unknow shape data when user set --dynamic_batch_size, --dynamic_image_size or --dynamic_dims"); |
|
|
|
return PARAM_INVALID; |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status MultiBatchClonePass::InitParamsOfGetNext(const NodePtr &node) { |
|
|
|
data_count_from_getnext_ = 0; |
|
|
|
getnext_sink_dynamic_dims_ = false; |
|
|
|