Browse Source

modify CheckInputShape func

tags/v1.3.0
zhengyuanhua 3 years ago
parent
commit
44e79db31d
2 changed files with 16 additions and 7 deletions
  1. +11
    -7
      ge/hybrid/executor/hybrid_model_executor.cc
  2. +5
    -0
      tests/ut/ge/hybrid/ge_hybrid_unittest.cc

+ 11
- 7
ge/hybrid/executor/hybrid_model_executor.cc View File

@@ -175,19 +175,16 @@ Status HybridModelExecutor::CheckInputShapeByShapeRange(const GraphItem *graph_i
HybridModelExecutor::ExecuteArgs &args) {
GE_CHECK_NOTNULL(graph_item);
auto input_nodes = graph_item->GetInputNodes();
if (args.input_desc.size() < input_nodes.size()) {
REPORT_INNER_ERROR("E19999", "[%s] Number of inputs [%zu] is not sufficient for graph which needs [%zu] inputs.",
graph_item->GetName().c_str(), args.input_desc.size(), input_nodes.size());
GELOGE(INTERNAL_ERROR, "[%s] Number of inputs [%zu] is not sufficient for graph which needs [%zu] inputs.",
graph_item->GetName().c_str(), args.input_desc.size(), input_nodes.size());
return INTERNAL_ERROR;
}
for (size_t i = 0; i < input_nodes.size(); ++i) {
auto &input_node = input_nodes[i];
if (input_node == nullptr) {
GELOGD("[%s] Input[%zu] is not needed by graph, skip it.", graph_item->GetName().c_str(), i);
continue;
}
if (!input_node->is_dynamic) {
GELOGD("[%s] Input[%zu] is not dynamic, skip it.", graph_item->GetName().c_str(), i);
continue;
}
GeTensorDescPtr model_input_desc = input_node->MutableInputDesc(0);
GE_CHECK_NOTNULL(model_input_desc);
std::vector<std::pair<int64_t, int64_t>> shape_range;
@@ -200,6 +197,13 @@ Status HybridModelExecutor::CheckInputShapeByShapeRange(const GraphItem *graph_i
GELOGD("[%s] Input[%zu] shape is not needed to check by shape range, skip it.", graph_item->GetName().c_str(), i);
continue;
}
if (i >= args.input_desc.size()) {
REPORT_INNER_ERROR("E19999", "[%s] Inputs[%zu] is greater than or equal to input desc size[%zu].",
graph_item->GetName().c_str(), i, args.input_desc.size());
GELOGE(INTERNAL_ERROR, "[%s] inputs[%zu] is greater than or equal to input desc size[%zu].",
graph_item->GetName().c_str(), i, args.input_desc.size());
return INTERNAL_ERROR;
}
ConstGeTensorDescPtr args_tensor_desc = args.input_desc[i];
GE_CHECK_NOTNULL(args_tensor_desc);
GeShape shape = args_tensor_desc->GetShape();


+ 5
- 0
tests/ut/ge/hybrid/ge_hybrid_unittest.cc View File

@@ -480,6 +480,7 @@ TEST_F(UtestGeHybrid, hybrid_model_executor_check_shape) {
NodePtr node = graph->AddNode(op_desc);
std::unique_ptr<NodeItem> new_node;
NodeItem::Create(node, new_node);
new_node->is_dynamic = true;

GraphItem graph_item;
graph_item.input_nodes_.emplace_back(new_node.get());
@@ -499,6 +500,10 @@ TEST_F(UtestGeHybrid, hybrid_model_executor_check_shape) {

ret = HybridModelExecutor::CheckInputShapeByShapeRange(&graph_item, args1);
ASSERT_EQ(ret, ge::INTERNAL_ERROR);

HybridModelExecutor::ExecuteArgs args3;
ret = HybridModelExecutor::CheckInputShapeByShapeRange(&graph_item, args3);
ASSERT_EQ(ret, ge::INTERNAL_ERROR);
}

TEST_F(UtestGeHybrid, TestOptimizeDependenciesForConstInputs) {


Loading…
Cancel
Save