Browse Source

set scalar tensor value range

tags/v1.5.1
wq160 3 years ago
parent
commit
59c97eef3e
4 changed files with 61 additions and 6 deletions
  1. +3
    -0
      ge/graph/passes/infer_value_range_pass.cc
  2. +9
    -5
      ge/graph/passes/replace_with_empty_const_pass.cc
  3. +1
    -1
      ge/graph/passes/replace_with_empty_const_pass.h
  4. +48
    -0
      tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc

+ 3
- 0
ge/graph/passes/infer_value_range_pass.cc View File

@@ -286,6 +286,9 @@ graphStatus InferValueRangePass::GenerateWorstValueRange(NodePtr &node) {
}

std::vector<std::pair<int64_t, int64_t>> output_i_value_range(output_i_shape_size, {1, -1});
if (output_i_shape.IsScalar()) {
output_i_value_range.emplace_back(1, -1);
}
output_desc->SetValueRange(output_i_value_range);
GELOGD("Node %s output %zu shape is %s, the generated worst value range is %s.", node->GetName().c_str(), i,
formats::ShapeToString(output_i_shape).c_str(), formats::RangeToString(output_i_value_range).c_str());


+ 9
- 5
ge/graph/passes/replace_with_empty_const_pass.cc View File

@@ -71,7 +71,7 @@ Status ReplaceWithEmptyConstPass::Run(NodePtr &node) {
GELOGI("Node %s Got empty output_desc_ptr, ignore current pass.", node->GetName().c_str());
return SUCCESS;
}
if (!IsEmptyTenor(output_desc_ptr->GetShape())) {
if (!IsKnownEmptyTenor(output_desc_ptr->GetShape())) {
is_all_output_empty = false;
break;
}
@@ -107,12 +107,16 @@ Status ReplaceWithEmptyConstPass::GetOutputsOfCurrNode(const NodePtr &node_to_re
return SUCCESS;
}

bool ReplaceWithEmptyConstPass::IsEmptyTenor(const GeShape &shape) const {
bool ReplaceWithEmptyConstPass::IsKnownEmptyTenor(const GeShape &shape) const {
bool is_known_empty_tensor = false;
for (auto dim : shape.GetDims()) {
if (dim == 0) {
return true;
if (dim < 0) {
// current dim is unknown dim, skip replace
return false;
} else if (dim == 0) {
is_known_empty_tensor = true;
}
}
return false;
return is_known_empty_tensor;
}
} // namespace ge

+ 1
- 1
ge/graph/passes/replace_with_empty_const_pass.h View File

@@ -26,7 +26,7 @@ class ReplaceWithEmptyConstPass : public FoldingPass {

private:
Status GetOutputsOfCurrNode(const NodePtr &node_to_replace, vector<GeTensorPtr> &outputs);
bool IsEmptyTenor(const GeShape &shape) const;
bool IsKnownEmptyTenor(const GeShape &shape) const;
};
} // namespace ge
#endif // GE_GRAPH_PASSES_REPLACE_WITH_EMPTY_CONST_PASS_H_

+ 48
- 0
tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc View File

@@ -362,6 +362,54 @@ TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsHave
EXPECT_EQ(unknown_target_value_range, output_value_range);
}

TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsHaveUnKnownValueRange_ScalarOutput) {
// shape --- add --- sqrt
// constant /
auto graph = std::make_shared<ComputeGraph>("test_graph");
vector<int64_t> data_vec = {1};
GeTensorDesc const_tensor_desc(ge::GeShape(), ge::FORMAT_NCHW, ge::DT_INT64);
GeTensorPtr const_tensor =
std::make_shared<ge::GeTensor>(const_tensor_desc, (uint8_t *)data_vec.data(), data_vec.size() * sizeof(int64_t));

auto const_op_desc = std::make_shared<OpDesc>("Constant", "Constant");
const_op_desc->AddOutputDesc(const_tensor_desc);
EXPECT_EQ(OpDescUtils::SetWeights(const_op_desc, const_tensor), GRAPH_SUCCESS);
auto const_node = graph->AddNode(const_op_desc);

GeTensorDesc shape_tensor_desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT64);
std::vector<std::pair<int64_t, int64_t>> unknown_value_range = {make_pair(1, -1)};
shape_tensor_desc.SetValueRange(unknown_value_range);
auto shape_op_desc = std::make_shared<OpDesc>("Shape", "Shape");
shape_op_desc->AddOutputDesc(shape_tensor_desc);
auto shape_node = graph->AddNode(shape_op_desc);

GeTensorDesc add_tensor_desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT64);
auto add_op_desc = std::make_shared<OpDesc>("Add", "Add");
add_op_desc->AddInputDesc(shape_tensor_desc);
add_op_desc->AddInputDesc(const_tensor_desc);
add_op_desc->AddOutputDesc(add_tensor_desc);
auto add_node = graph->AddNode(add_op_desc);

ge::GraphUtils::AddEdge(shape_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(0));
ge::GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(1));

// test unknown value range
InferValueRangePass infer_pass;
EXPECT_EQ(infer_pass.Run(add_node), SUCCESS);
auto output_0_desc = add_node->GetOpDesc()->GetOutputDesc(0);
std::vector<std::pair<int64_t, int64_t>> out_value_range;
output_0_desc.GetValueRange(out_value_range);
EXPECT_EQ(out_value_range.size(), 1);

std::vector<int64_t> unknown_target_value_range = {1, -1};
std::vector<int64_t> output_value_range;
for (auto pair : out_value_range) {
output_value_range.push_back(pair.first);
output_value_range.push_back(pair.second);
}
EXPECT_EQ(unknown_target_value_range, output_value_range);
}

TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsAreKnownValueRange_Int64) {
// shape --- add --- sqrt
// constant /


Loading…
Cancel
Save