Merge pull request !1952 from 王强/mastertags/v1.5.1
@@ -85,8 +85,16 @@ graphStatus InferValueRangePass::Infer(NodePtr &node) { | |||
return GRAPH_SUCCESS; | |||
} | |||
// if input value range has -1, cpu kernel cannot calculate correctly, so set {1:-1} | |||
if (InputHasUnknownValueRange(node)) { | |||
// Handle the case where some input value range is unknown (contains -1) | |||
bool has_unknown_value_range = false; | |||
bool has_zero_in_value_range = false; | |||
CheckInputValueRange(node, has_unknown_value_range, has_zero_in_value_range); | |||
if (has_unknown_value_range) { | |||
if (has_zero_in_value_range) { | |||
// When there is zero in input value range, it is unreasonable to always set output value range {1:-1}. | |||
GELOGW("Node %s has -1 and 0 in value range, skip setting value range.", node->GetName().c_str()); | |||
return GRAPH_NOT_CHANGED; | |||
} | |||
GELOGI("Node %s has unknown value range in input tensors, set value range {1:-1}, and skip cpu kernel.", | |||
node->GetName().c_str()); | |||
return GenerateWorstValueRange(node); | |||
@@ -188,14 +196,21 @@ bool InferValueRangePass::InputIsConstOrHasValueRange(const NodePtr &node) const | |||
return input_is_const_or_has_value_range; | |||
} | |||
bool InferValueRangePass::InputHasUnknownValueRange(const NodePtr &node) const { | |||
bool has_unknown_value_range = false; | |||
void InferValueRangePass::CheckInputValueRange(const NodePtr &node, bool &has_unknown_value_range, | |||
bool &has_zero_in_value_range) const { | |||
has_unknown_value_range = false; | |||
has_zero_in_value_range = false; | |||
auto cur_op_desc = node->GetOpDesc(); | |||
for (const auto &input_desc : cur_op_desc->GetAllInputsDescPtr()) { | |||
std::vector<std::pair<int64_t, int64_t>> input_desc_value_range; | |||
input_desc->GetValueRange(input_desc_value_range); | |||
if (!input_desc_value_range.empty()) { | |||
for (const auto &range : input_desc_value_range) { | |||
if (range.first == 0 || range.second == 0) { | |||
GELOGD("Node %s input tensors have zero in value range %s.", node->GetName().c_str(), | |||
formats::RangeToString(input_desc_value_range).c_str()); | |||
has_zero_in_value_range = true; | |||
} | |||
if (range.first == -1 || range.second == -1) { | |||
GELOGD("Node %s input tensors have unknown value range, value range is %s.", node->GetName().c_str(), | |||
formats::RangeToString(input_desc_value_range).c_str()); | |||
@@ -204,7 +219,6 @@ bool InferValueRangePass::InputHasUnknownValueRange(const NodePtr &node) const { | |||
} | |||
} | |||
} | |||
return has_unknown_value_range; | |||
} | |||
graphStatus InferValueRangePass::UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) { | |||
@@ -34,7 +34,7 @@ class InferValueRangePass : public InferBasePass { | |||
bool InputIsDynamic(const NodePtr &node) const; | |||
bool InputIsConstOrHasValueRange(const NodePtr &node) const; | |||
bool InputHasUnknownValueRange(const NodePtr &node) const; | |||
void CheckInputValueRange(const NodePtr &node, bool &has_unknown_value_range, bool &has_zero_in_value_range) const; | |||
graphStatus GenerateWorstValueRange(NodePtr &node); | |||
template <typename T> | |||
graphStatus ConstructData(const GeTensorDesc &tensor_desc, bool use_floor_value, GeTensorPtr &output_ptr); | |||
@@ -365,6 +365,35 @@ TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsHave | |||
EXPECT_EQ(unknown_target_value_range, output_value_range); | |||
} | |||
// Checks that running InferValueRangePass on a node whose input value ranges
// contain both an unknown bound (-1) and a zero bound leaves the node's output
// value range empty.
TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsHaveZeroInValueRange) { | |||
// shape ===> add  (Shape output 0 feeds both Add inputs; no other node is built in this test) | |||
auto graph = std::make_shared<ComputeGraph>("test_graph"); | |||
GeTensorDesc shape_tensor_desc(GeShape({2}), ge::FORMAT_NCHW, ge::DT_INT64); | |||
// First pair has an unknown (-1) upper bound; second pair has a zero lower bound.
std::vector<std::pair<int64_t, int64_t>> unknown_value_range = {make_pair(1, -1), make_pair(0, 240)}; | |||
shape_tensor_desc.SetValueRange(unknown_value_range); | |||
auto shape_op_desc = std::make_shared<OpDesc>("Shape", "Shape"); | |||
shape_op_desc->AddOutputDesc(shape_tensor_desc); | |||
auto shape_node = graph->AddNode(shape_op_desc); | |||
// The Add output desc is created without any value range set on it.
GeTensorDesc add_tensor_desc(GeShape({2}), ge::FORMAT_NCHW, ge::DT_INT64); | |||
auto add_op_desc = std::make_shared<OpDesc>("Add", "Add"); | |||
// Both Add inputs reuse the Shape tensor desc, so both carry the -1 and 0 bounds.
add_op_desc->AddInputDesc(shape_tensor_desc); | |||
add_op_desc->AddInputDesc(shape_tensor_desc); | |||
add_op_desc->AddOutputDesc(add_tensor_desc); | |||
auto add_node = graph->AddNode(add_op_desc); | |||
ge::GraphUtils::AddEdge(shape_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(0)); | |||
ge::GraphUtils::AddEdge(shape_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(1)); | |||
// test unknown value range | |||
InferValueRangePass infer_pass; | |||
EXPECT_EQ(infer_pass.Run(add_node), SUCCESS); | |||
auto output_0_desc = add_node->GetOpDesc()->GetOutputDesc(0); | |||
std::vector<std::pair<int64_t, int64_t>> out_value_range; | |||
output_0_desc.GetValueRange(out_value_range); | |||
// The output value range must still be empty after the pass has run.
EXPECT_EQ(out_value_range.size(), 0); | |||
} | |||
TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsHaveUnKnownValueRange_ScalarOutput) { | |||
// shape --- add --- sqrt | |||
// constant / | |||