Browse Source

checkout input user dim modify

tags/v1.2.0
zhengyuanhua 3 years ago
parent
commit
ed6a811c15
2 changed files with 22 additions and 7 deletions
  1. +8
    -7
      ge/graph/preprocess/graph_preprocess.cc
  2. +14
    -0
      tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc

+ 8
- 7
ge/graph/preprocess/graph_preprocess.cc View File

@@ -23,6 +23,7 @@
#include "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h"
#include "common/formats/format_transfers/format_transfer_transpose.h"
#include "common/formats/utils/formats_trans_utils.h"
#include "common/util/error_manager/error_manager.h"
#include "common/helper/model_helper.h"
#include "common/math/math_util.h"
#include "common/op/ge_op_utils.h"
@@ -1763,13 +1764,13 @@ Status GraphPrepare::CheckUserInput(const std::vector<GeTensor> &user_input) {
GeTensorDesc desc(user_input[index].GetTensorDesc());

for (size_t i = 0; i < desc.GetShape().GetDimNum(); ++i) {
if (desc.GetShape().GetDim(i) < 0) {
std::string situation = "data dim[" + std::to_string(i) + "][" +
std::to_string(desc.GetShape().GetDim(i)) + "]" ;
std::string reason = "it need >= 0";
ErrorManager::GetInstance().ATCReportErrMessage("E19025", {"situation", "reason"}, {situation, reason});
GELOGE(GE_GRAPH_INIT_FAILED, "data dim %zu is not supported, need >= 0, real:%ld.", i,
desc.GetShape().GetDim(i));
int64_t dim = desc.GetShape().GetDim(i);
if (dim < UNKNOWN_DIM_NUM) {
std::string situation = "data dim[" + std::to_string(i) + "][" + std::to_string(dim) + "]" ;
std::string reason = "it need >= -2";
REPORT_INPUT_ERROR(
"E19025", std::vector<std::string>({"situation", "reason"}),std::vector<std::string>({situation, reason}));
GELOGE(GE_GRAPH_INIT_FAILED, "[Check][InputDim]data dim %zu is not supported, need >= -2, real:%ld.", i, dim);
return GE_GRAPH_INIT_FAILED;
}
}


+ 14
- 0
tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc View File

@@ -74,4 +74,18 @@ TEST_F(UtestGraphPreproces, test_dynamic_input_shape_parse) {
EXPECT_EQ(result_shape.GetDim(i), expect_shape.at(i));
}
}

TEST_F(UtestGraphPreproces, test_check_user_input) {
ge::GraphPrepare graph_prepare;
graph_prepare.compute_graph_ = BuildGraph1();

vector<int64_t> dim = {2, -3};
GeTensor tensor;
tensor.SetTensorDesc(GeTensorDesc(GeShape(dim)));
std::vector<GeTensor> user_input;
user_input.emplace_back(tensor);

Status ret = graph_prepare.CheckUserInput(user_input);
EXPECT_EQ(ret, GE_GRAPH_INIT_FAILED);
}
}

Loading…
Cancel
Save