Browse Source

FillKernel

tags/v1.5.1
lianghao 3 years ago
parent
commit
55a5e8019d
2 changed files with 44 additions and 0 deletions
  1. +8
    -0
      ge/host_kernels/fill_kernel.cc
  2. +36
    -0
      tests/ut/ge/graph/passes/folding_kernel/fill_kernel_unittest.cc

+ 8
- 0
ge/host_kernels/fill_kernel.cc View File

@@ -45,6 +45,7 @@ Status FillKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vector<ge
GELOGE(PARAM_INVALID, "Parameter's invalid, Input opDescPtr is nullptr."); GELOGE(PARAM_INVALID, "Parameter's invalid, Input opDescPtr is nullptr.");
return PARAM_INVALID; return PARAM_INVALID;
} }
GELOGD("FillKernel in, name: %s.", op_desc_ptr->GetName().c_str());


GE_CHECK_NOTNULL(input.at(kFillDimsInputIndex)); GE_CHECK_NOTNULL(input.at(kFillDimsInputIndex));
GE_CHECK_NOTNULL(input.at(kFillDataInputIndex)); GE_CHECK_NOTNULL(input.at(kFillDataInputIndex));
@@ -57,6 +58,13 @@ Status FillKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vector<ge
return NOT_CHANGED; return NOT_CHANGED;
} }


auto output_desc = op_desc_ptr->GetOutputDescPtr(0);
GE_CHECK_NOTNULL(output_desc);
if (output_desc->GetShape().IsUnknownShape()) {
GELOGD("Output is unknown shape, [%s] skip FillKernel.", op_desc_ptr->GetName().c_str());
return NOT_CHANGED;
}

GeTensorPtr output_ptr; GeTensorPtr output_ptr;
output_ptr = MakeShared<GeTensor>(op_desc_ptr->GetOutputDesc(0)); output_ptr = MakeShared<GeTensor>(op_desc_ptr->GetOutputDesc(0));
if (output_ptr == nullptr) { if (output_ptr == nullptr) {


+ 36
- 0
tests/ut/ge/graph/passes/folding_kernel/fill_kernel_unittest.cc View File

@@ -64,6 +64,7 @@ class UtestGraphPassesFoldingKernelFillKernel : public testing::Test {


op_desc_ptr->AddInputDesc(dims_tensor_desc); op_desc_ptr->AddInputDesc(dims_tensor_desc);
op_desc_ptr->AddInputDesc(value_tensor_desc); op_desc_ptr->AddInputDesc(value_tensor_desc);
op_desc_ptr->AddOutputDesc(dims_tensor_desc);


std::vector<ge::ConstGeTensorPtr> input = {dim_tensor, value_tensor}; std::vector<ge::ConstGeTensorPtr> input = {dim_tensor, value_tensor};
std::vector<GeTensorPtr> outputs; std::vector<GeTensorPtr> outputs;
@@ -124,6 +125,7 @@ TEST_F(UtestGraphPassesFoldingKernelFillKernel, FillBoolShape2And3) {


op_desc_ptr->AddInputDesc(dims_tensor_desc); op_desc_ptr->AddInputDesc(dims_tensor_desc);
op_desc_ptr->AddInputDesc(value_tensor_desc); op_desc_ptr->AddInputDesc(value_tensor_desc);
op_desc_ptr->AddOutputDesc(dims_tensor_desc);


std::vector<ge::ConstGeTensorPtr> input = {dim_tensor, value_tensor}; std::vector<ge::ConstGeTensorPtr> input = {dim_tensor, value_tensor};
std::vector<GeTensorPtr> outputs; std::vector<GeTensorPtr> outputs;
@@ -230,6 +232,7 @@ TEST_F(UtestGraphPassesFoldingKernelFillKernel, FillDimsHaveNegativeNumber) {


op_desc_ptr->AddInputDesc(dims_tensor_desc); op_desc_ptr->AddInputDesc(dims_tensor_desc);
op_desc_ptr->AddInputDesc(value_tensor_desc); op_desc_ptr->AddInputDesc(value_tensor_desc);
op_desc_ptr->AddOutputDesc(dims_tensor_desc);


std::vector<ge::ConstGeTensorPtr> input = {dim_tensor, value_tensor}; std::vector<ge::ConstGeTensorPtr> input = {dim_tensor, value_tensor};
std::vector<GeTensorPtr> outputs; std::vector<GeTensorPtr> outputs;
@@ -284,6 +287,7 @@ TEST_F(UtestGraphPassesFoldingKernelFillKernel, FillDimsTypeNotSupport) {


op_desc_ptr->AddInputDesc(dims_tensor_desc); op_desc_ptr->AddInputDesc(dims_tensor_desc);
op_desc_ptr->AddInputDesc(value_tensor_desc); op_desc_ptr->AddInputDesc(value_tensor_desc);
op_desc_ptr->AddOutputDesc(dims_tensor_desc);


std::vector<ge::ConstGeTensorPtr> input = {dim_tensor, value_tensor}; std::vector<ge::ConstGeTensorPtr> input = {dim_tensor, value_tensor};
std::vector<GeTensorPtr> outputs; std::vector<GeTensorPtr> outputs;
@@ -310,6 +314,7 @@ TEST_F(UtestGraphPassesFoldingKernelFillKernel, FillDimsOverflow) {


op_desc_ptr->AddInputDesc(dims_tensor_desc); op_desc_ptr->AddInputDesc(dims_tensor_desc);
op_desc_ptr->AddInputDesc(value_tensor_desc); op_desc_ptr->AddInputDesc(value_tensor_desc);
op_desc_ptr->AddOutputDesc(dims_tensor_desc);


std::vector<ge::ConstGeTensorPtr> input = {dim_tensor, value_tensor}; std::vector<ge::ConstGeTensorPtr> input = {dim_tensor, value_tensor};
std::vector<GeTensorPtr> outputs; std::vector<GeTensorPtr> outputs;
@@ -336,6 +341,7 @@ TEST_F(UtestGraphPassesFoldingKernelFillKernel, FillDimsMulDataTypeOverflow) {


op_desc_ptr->AddInputDesc(dims_tensor_desc); op_desc_ptr->AddInputDesc(dims_tensor_desc);
op_desc_ptr->AddInputDesc(value_tensor_desc); op_desc_ptr->AddInputDesc(value_tensor_desc);
op_desc_ptr->AddOutputDesc(dims_tensor_desc);


std::vector<ge::ConstGeTensorPtr> input = {dim_tensor, value_tensor}; std::vector<ge::ConstGeTensorPtr> input = {dim_tensor, value_tensor};
std::vector<GeTensorPtr> outputs; std::vector<GeTensorPtr> outputs;
@@ -343,3 +349,33 @@ TEST_F(UtestGraphPassesFoldingKernelFillKernel, FillDimsMulDataTypeOverflow) {


EXPECT_EQ(PARAM_INVALID, status); EXPECT_EQ(PARAM_INVALID, status);
} }

TEST_F(UtestGraphPassesFoldingKernelFillKernel, OutputdescUnknown) {
ge::OpDescPtr op_dims = std::make_shared<ge::OpDesc>();
vector <int64_t> dims_vec = {2};
vector <int32_t> dims_value_vec = {2, 3};
GeTensorDesc dims_tensor_desc(GeShape(dims_vec), FORMAT_NCHW, DT_INT32);
GeTensorPtr dim_tensor = std::make_shared<GeTensor>(dims_tensor_desc, (uint8_t *) dims_value_vec.data(),
dims_value_vec.size() * sizeof(int32_t));
OpDescUtils::SetWeights(op_dims, dim_tensor);

ge::OpDescPtr op_value = std::make_shared<ge::OpDesc>();
vector <uint8_t> data_vec = {1};
GeTensorDesc value_tensor_desc(GeShape(), FORMAT_NCHW, DT_BOOL);
GeTensorPtr value_tensor =
std::make_shared<GeTensor>(value_tensor_desc, (uint8_t *) data_vec.data(), data_vec.size() * sizeof(bool));
OpDescUtils::SetWeights(op_value, value_tensor);

op_desc_ptr->AddInputDesc(dims_tensor_desc);
op_desc_ptr->AddInputDesc(value_tensor_desc);

vector <int64_t> out_vec = {-1, -1};
GeTensorDesc out_tensor_desc(GeShape(out_vec), FORMAT_NCHW, DT_INT32);
op_desc_ptr->AddOutputDesc(out_tensor_desc);

std::vector <ge::ConstGeTensorPtr> input = {dim_tensor, value_tensor};
std::vector <GeTensorPtr> outputs;
Status status = kernel->Compute(op_desc_ptr, input, outputs);

EXPECT_EQ(NOT_CHANGED, status);
}

Loading…
Cancel
Save