From 44e79db31d58f59c9d6e0aaf04bbd4d975bf930a Mon Sep 17 00:00:00 2001 From: zhengyuanhua Date: Thu, 8 Apr 2021 16:24:52 +0800 Subject: [PATCH] modify CheckInputShape func --- ge/hybrid/executor/hybrid_model_executor.cc | 18 +++++++++++------- tests/ut/ge/hybrid/ge_hybrid_unittest.cc | 5 +++++ 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/ge/hybrid/executor/hybrid_model_executor.cc b/ge/hybrid/executor/hybrid_model_executor.cc index 4a8a0af0..6addd9b5 100755 --- a/ge/hybrid/executor/hybrid_model_executor.cc +++ b/ge/hybrid/executor/hybrid_model_executor.cc @@ -175,19 +175,16 @@ Status HybridModelExecutor::CheckInputShapeByShapeRange(const GraphItem *graph_i HybridModelExecutor::ExecuteArgs &args) { GE_CHECK_NOTNULL(graph_item); auto input_nodes = graph_item->GetInputNodes(); - if (args.input_desc.size() < input_nodes.size()) { - REPORT_INNER_ERROR("E19999", "[%s] Number of inputs [%zu] is not sufficient for graph which needs [%zu] inputs.", - graph_item->GetName().c_str(), args.input_desc.size(), input_nodes.size()); - GELOGE(INTERNAL_ERROR, "[%s] Number of inputs [%zu] is not sufficient for graph which needs [%zu] inputs.", - graph_item->GetName().c_str(), args.input_desc.size(), input_nodes.size()); - return INTERNAL_ERROR; - } for (size_t i = 0; i < input_nodes.size(); ++i) { auto &input_node = input_nodes[i]; if (input_node == nullptr) { GELOGD("[%s] Input[%zu] is not needed by graph, skip it.", graph_item->GetName().c_str(), i); continue; } + if (!input_node->is_dynamic) { + GELOGD("[%s] Input[%zu] is not dynamic, skip it.", graph_item->GetName().c_str(), i); + continue; + } GeTensorDescPtr model_input_desc = input_node->MutableInputDesc(0); GE_CHECK_NOTNULL(model_input_desc); std::vector> shape_range; @@ -200,6 +197,13 @@ Status HybridModelExecutor::CheckInputShapeByShapeRange(const GraphItem *graph_i GELOGD("[%s] Input[%zu] shape is not needed to check by shape range, skip it.", graph_item->GetName().c_str(), i); continue; } + if (i >= args.input_desc.size()) { + REPORT_INNER_ERROR("E19999", "[%s] Inputs[%zu] is greater than or equal to input desc size[%zu].", + graph_item->GetName().c_str(), i, args.input_desc.size()); + GELOGE(INTERNAL_ERROR, "[%s] inputs[%zu] is greater than or equal to input desc size[%zu].", + graph_item->GetName().c_str(), i, args.input_desc.size()); + return INTERNAL_ERROR; + } ConstGeTensorDescPtr args_tensor_desc = args.input_desc[i]; GE_CHECK_NOTNULL(args_tensor_desc); GeShape shape = args_tensor_desc->GetShape(); diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index 9746585d..95952271 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -480,6 +480,7 @@ TEST_F(UtestGeHybrid, hybrid_model_executor_check_shape) { NodePtr node = graph->AddNode(op_desc); std::unique_ptr new_node; NodeItem::Create(node, new_node); + new_node->is_dynamic = true; GraphItem graph_item; graph_item.input_nodes_.emplace_back(new_node.get()); @@ -499,6 +500,10 @@ TEST_F(UtestGeHybrid, hybrid_model_executor_check_shape) { ret = HybridModelExecutor::CheckInputShapeByShapeRange(&graph_item, args1); ASSERT_EQ(ret, ge::INTERNAL_ERROR); + + HybridModelExecutor::ExecuteArgs args3; + ret = HybridModelExecutor::CheckInputShapeByShapeRange(&graph_item, args3); + ASSERT_EQ(ret, ge::INTERNAL_ERROR); } TEST_F(UtestGeHybrid, TestOptimizeDependenciesForConstInputs) {