From b369c392bfb1dcea624ead7daca3d85a2702a8c4 Mon Sep 17 00:00:00 2001 From: wq160 Date: Tue, 15 Jun 2021 15:18:19 +0800 Subject: [PATCH] add infer_base_pass and infer_value_range_pass --- ge/CMakeLists.txt | 4 + .../formats/utils/formats_trans_utils.cc | 13 + ge/common/formats/utils/formats_trans_utils.h | 2 + ge/graph/passes/constant_folding_pass.cc | 26 +- ge/graph/passes/constant_folding_pass.h | 5 + ge/graph/passes/folding_pass.cc | 8 - ge/graph/passes/folding_pass.h | 2 - ge/graph/passes/infer_base_pass.cc | 585 +++++++++++++ ge/graph/passes/infer_base_pass.h | 50 ++ ge/graph/passes/infer_value_range_pass.cc | 383 ++++++++ ge/graph/passes/infer_value_range_pass.h | 45 + ge/graph/preprocess/graph_preprocess.cc | 3 + tests/ut/ge/CMakeLists.txt | 5 + .../passes/infer_value_range_pass_unittest.cc | 816 ++++++++++++++++++ 14 files changed, 1926 insertions(+), 21 deletions(-) create mode 100644 ge/graph/passes/infer_base_pass.cc create mode 100644 ge/graph/passes/infer_base_pass.h create mode 100644 ge/graph/passes/infer_value_range_pass.cc create mode 100644 ge/graph/passes/infer_value_range_pass.h create mode 100644 tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index fb5b2ef6..e1ed0f8f 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -298,7 +298,9 @@ set(TRAIN_SRC_LIST "graph/passes/hccl_continuous_memcpy_pass.cc" "graph/passes/identity_pass.cc" "graph/passes/ref_identity_delete_op_pass.cc" + "graph/passes/infer_base_pass.cc" "graph/passes/infershape_pass.cc" + "graph/passes/infer_value_range_pass.cc" "graph/passes/iterator_op_pass.cc" "graph/passes/link_gen_mask_nodes_pass.cc" "graph/passes/merge_pass.cc" @@ -553,7 +555,9 @@ set(INFER_SRC_LIST "graph/passes/shape_operate_op_remove_pass.cc" "graph/passes/assert_pass.cc" "graph/passes/dropout_pass.cc" + "graph/passes/infer_base_pass.cc" "graph/passes/infershape_pass.cc" + "graph/passes/infer_value_range_pass.cc" "graph/passes/unused_const_pass.cc" "graph/passes/permute_pass.cc" "graph/passes/ctrl_edge_transfer_pass.cc" diff --git a/ge/common/formats/utils/formats_trans_utils.cc b/ge/common/formats/utils/formats_trans_utils.cc index 052951ce..65cfcd19 100755 --- a/ge/common/formats/utils/formats_trans_utils.cc +++ b/ge/common/formats/utils/formats_trans_utils.cc @@ -49,6 +49,19 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string ShapeToString(const s return JoinToString(shape); } +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +std::string RangeToString(const std::vector> &range) { + string serial_string; + serial_string += "["; + for (const auto &pair : range) { + serial_string += "{"; + serial_string += std::to_string(pair.first) + "," + std::to_string(pair.second); + serial_string += "},"; + } + serial_string += "]"; + return serial_string; +} + int64_t GetItemNumByShape(const std::vector &shape) { int64_t num = 1; for (auto dim : shape) { diff --git a/ge/common/formats/utils/formats_trans_utils.h b/ge/common/formats/utils/formats_trans_utils.h index 848e8b3a..b1384ea4 100755 --- a/ge/common/formats/utils/formats_trans_utils.h +++ b/ge/common/formats/utils/formats_trans_utils.h @@ -54,6 +54,8 @@ std::string ShapeToString(const GeShape &shape); std::string ShapeToString(const std::vector &shape); +std::string RangeToString(const std::vector> &range); + int64_t GetItemNumByShape(const std::vector &shape); bool CheckShapeValid(const std::vector &shape, const int64_t expect_dims); diff --git a/ge/graph/passes/constant_folding_pass.cc b/ge/graph/passes/constant_folding_pass.cc index 25fe26da..3112f378 100644 --- a/ge/graph/passes/constant_folding_pass.cc +++ b/ge/graph/passes/constant_folding_pass.cc @@ -20,17 +20,23 @@ #include "graph/operator_factory.h" #include "graph/utils/node_utils.h" #include "graph/utils/type_utils.h" +#include "ge_local_engine/engine/host_cpu_engine.h" #include "init/gelib.h" namespace ge { const int64_t kStartCallNum = 1; const std::string kKernelLibName = "aicpu_tf_kernel"; -// tf_kernel.json opsFlag config const std::string kOpsFlagClose = "0"; -Status RunOpKernelWithCheck(NodePtr &node, - const vector &inputs, - std::vector &outputs) { +const map> &ConstantFoldingPass::GetGeConstantFoldingPerfStatistic() const { + return statistic_of_ge_constant_folding_; +} +const map> &ConstantFoldingPass::GetOpConstantFoldingPerfStatistic() const { + return statistic_of_op_constant_folding_; +} + +Status ConstantFoldingPass::RunOpKernelWithCheck(NodePtr &node, const vector &inputs, + std::vector &outputs) { std::shared_ptr instance_ptr = ge::GELib::GetInstance(); if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) { GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Check][Param] GE is not initialized or is finalized."); @@ -47,15 +53,13 @@ Status RunOpKernelWithCheck(NodePtr &node, if (ops_flag == kOpsFlagClose) { return UNSUPPORTED; } - return FoldingPass::RunOpKernel(node, inputs, outputs); + return RunOpKernel(node, inputs, outputs); } -const map> &ConstantFoldingPass::GetGeConstantFoldingPerfStatistic() const { - return statistic_of_ge_constant_folding_; -} - -const map> &ConstantFoldingPass::GetOpConstantFoldingPerfStatistic() const { - return statistic_of_op_constant_folding_; +Status ConstantFoldingPass::RunOpKernel(NodePtr &node, + const vector &inputs, + std::vector &outputs) { + return HostCpuEngine::GetInstance().Run(node, inputs, outputs); } Status ConstantFoldingPass::Run(ge::NodePtr &node) { diff --git a/ge/graph/passes/constant_folding_pass.h b/ge/graph/passes/constant_folding_pass.h index 703e6edd..7de48a17 100644 --- a/ge/graph/passes/constant_folding_pass.h +++ b/ge/graph/passes/constant_folding_pass.h @@ -28,6 +28,11 @@ class ConstantFoldingPass : public FoldingPass { Status Run(ge::NodePtr &node) override; const std::map> &GetGeConstantFoldingPerfStatistic() const; const std::map> &GetOpConstantFoldingPerfStatistic() const; + + static Status RunOpKernel(NodePtr &node, const vector &inputs, vector &outputs); + static Status RunOpKernelWithCheck(NodePtr &node, const vector &inputs, + std::vector &outputs); + private: std::map> statistic_of_op_constant_folding_; std::map> statistic_of_ge_constant_folding_; diff --git a/ge/graph/passes/folding_pass.cc b/ge/graph/passes/folding_pass.cc index c0a0f2a2..819c3b40 100755 --- a/ge/graph/passes/folding_pass.cc +++ b/ge/graph/passes/folding_pass.cc @@ -28,8 +28,6 @@ #include "inc/kernel.h" #include "inc/kernel_factory.h" #include "graph/debug/ge_attr_define.h" -#include "ge_local_engine/engine/host_cpu_engine.h" - namespace ge { namespace folding_pass { @@ -123,12 +121,6 @@ NodePtr AddIdentityNodeToGraph(const std::string &name, const GeTensorDesc &tens } } // namespace -Status FoldingPass::RunOpKernel(NodePtr &node, - const vector &inputs, - std::vector &outputs) { - return HostCpuEngine::GetInstance().Run(node, inputs, outputs); -} - Status FoldingPass::Folding(NodePtr &node, vector &outputs) { GE_CHECK_NOTNULL(node); GELOGD("begin folding node:%s", node->GetName().c_str()); diff --git a/ge/graph/passes/folding_pass.h b/ge/graph/passes/folding_pass.h index 745cffd7..c461ff5c 100755 --- a/ge/graph/passes/folding_pass.h +++ b/ge/graph/passes/folding_pass.h @@ -34,8 +34,6 @@ bool IsNoNeedConstantFolding(const NodePtr &node); using IndexsToAnchors = std::map>; class FoldingPass : public BaseNodePass { - public: - static Status RunOpKernel(NodePtr &node, const vector &inputs, vector &outputs); protected: Status Folding(NodePtr &node, vector &outputs); private: diff --git a/ge/graph/passes/infer_base_pass.cc b/ge/graph/passes/infer_base_pass.cc new file mode 100644 index 00000000..4e9f8a29 --- /dev/null +++ b/ge/graph/passes/infer_base_pass.cc @@ -0,0 +1,585 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "infer_base_pass.h" +#include "common/ge/ge_util.h" +#include "common/formats/utils/formats_trans_utils.h" +#include "common/util/error_manager/error_manager.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/util.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/debug/ge_util.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/node_utils.h" +#include "graph/utils/tensor_utils.h" +#include "graph/utils/type_utils.h" + +namespace ge { +namespace { +void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) { + std::vector> shape_range; + (void)desc->GetShapeRange(shape_range); + desc_str += formats::RangeToString(shape_range); + shape_range.clear(); + (void)desc->GetOriginShapeRange(shape_range); + desc_str += ","; + desc_str += formats::RangeToString(shape_range); + shape_range.clear(); +} + +graphStatus FindSubgraphDataAndNetoutput(const ComputeGraphPtr &sub_graph, NodePtr &netoutput, const ConstNodePtr &node, + std::vector> &ref_data_tensors) { + auto sub_nodes = sub_graph->GetDirectNode(); + for (size_t i = sub_nodes.size(); i > 0; --i) { + auto sub_node = sub_nodes.at(i - 1); + if (sub_node->GetType() == NETOUTPUT) { + netoutput = sub_node; + } + if (sub_node->GetType() == DATA) { + if (sub_node->GetOpDesc() == nullptr) { + return GRAPH_FAILED; + } + + int ref_i; + if (!AttrUtils::GetInt(sub_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { + REPORT_INNER_ERROR("E19999", "subgraph data node[%s] has no parent node!", sub_node->GetName().c_str()); + GELOGE(GRAPH_FAILED, "[Get][Int] subgraph data node[%s] has no parent node!", sub_node->GetName().c_str()); + return GRAPH_FAILED; + } + if (ref_i < 0 || static_cast(ref_i) >= node->GetAllInDataAnchorsSize()) { + REPORT_INNER_ERROR("E19999", "data node[%s]'s ref index[%d] is not in range [0, %u)!", + sub_node->GetName().c_str(), ref_i, node->GetAllInDataAnchorsSize()); + GELOGE(GRAPH_FAILED, "[Check][Param] data node[%s]'s ref index[%d] is not in range [0, %u)!", + sub_node->GetName().c_str(), ref_i, node->GetAllInDataAnchorsSize()); + return GRAPH_FAILED; + } + ref_data_tensors[ref_i].emplace_back(sub_node->GetOpDesc()->GetOutputDesc(0)); + } + } + return GRAPH_SUCCESS; +} +} // namespace + +Status InferBasePass::Run(NodePtr &node) { + GE_CHECK_NOTNULL(node); + GE_CHECK_NOTNULL(node->GetOpDesc()); + + bool need_infer = NeedInfer(node); + if (!need_infer) { + GELOGD("Node %s does not need to infer.", node->GetName().c_str()); + return SUCCESS; + } + + std::set changed_nodes; + auto ret = InferAndUpdate(node, !OptionExists(kOptimizeAfterSubGraph), changed_nodes); + if (ret != GRAPH_SUCCESS) { + GELOGE(ret, "Infer and update for node %s failed! ret: %u", node->GetName().c_str(), ret); + return GRAPH_FAILED; + } + + AddChangedNodesImmediateRepass(changed_nodes); + return SUCCESS; +} + +bool InferBasePass::NeedInfer(const NodePtr &node) { return true; } +void InferBasePass::AddChangedNodesImmediateRepass(const std::set &changed_nodes) { + for (const auto &node_ele : changed_nodes) { + AddImmediateRePassNode(node_ele); + } +} + +graphStatus InferBasePass::InferAndUpdate(NodePtr &node, bool before_subgraph, std::set &changed_nodes) { + graphStatus ret ; + bool contain_subgraph = ContainsSubgraph(node); + if (contain_subgraph && before_subgraph) { + ret = UpdateTensorDescToSubgraphData(node, changed_nodes); + if (ret != GRAPH_SUCCESS) { + GELOGE(ret, "Update subgraph data tensor desc for node %s failed! ret: %u", node->GetName().c_str(), ret); + return ret; + } + } + ret = Infer(node); + if (ret != GRAPH_SUCCESS) { + GELOGE(ret, "Infer failed for node %s, ret: %u", node->GetName().c_str(), ret); + return ret; + } + if (contain_subgraph && !before_subgraph) { + ret = UpdateTensorDescToParentNode(node, changed_nodes); + if (ret != GRAPH_SUCCESS) { + GELOGE(ret, "Update parent tensor desc for node %s failed! ret: %u", node->GetName().c_str(), ret); + return ret; + } + } + + ret = UpdateTensorDescToPeerInputs(node, changed_nodes); + if (ret != GRAPH_SUCCESS) { + GELOGE(ret, "Node %s updates tensor desc to peer input nodes failed! ret: %u", node->GetName().c_str(), ret); + } + return ret; +} + +bool InferBasePass::ContainsSubgraph(const NodePtr &node) { + auto op_desc = node->GetOpDesc(); + auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); + if (sub_graph_names.empty()) { + return false; + } + + auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); + if (root_graph == nullptr) { + return false; + } + for (const auto &name : sub_graph_names) { + if (name.empty()) { + continue; + } + auto sub_graph = root_graph->GetSubgraph(name); + if (sub_graph != nullptr) { + return true; + } + } + return false; +} + +graphStatus InferBasePass::UpdateTensorDescToPeerInputs(NodePtr &node, std::set &changed_nodes) { + PrintInOutTensorShape(node, "after_infer"); + auto op_desc = node->GetOpDesc(); + for (const auto &out_anchor : node->GetAllOutDataAnchors()) { + auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx()); + for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) { + auto peer_anchor_opdesc = peer_anchor->GetOwnerNode()->GetOpDesc(); + if (peer_anchor_opdesc == nullptr) { + continue; + } + auto peer_input_desc = peer_anchor_opdesc->MutableInputDesc(peer_anchor->GetIdx()); + if (peer_input_desc == nullptr) { + continue; + } + + bool changed = false; + auto ret = UpdatePeerInputDesc(output_tensor, peer_input_desc, changed); + if (ret != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E19999", "Update peer input desc failed, node %s.", node->GetName().c_str()); + GELOGE(ret, "Update peer input desc failed, node %s.", node->GetName().c_str()); + return ret; + } + if (changed) { + changed_nodes.insert(peer_anchor->GetOwnerNode()); + } + } + } + return GRAPH_SUCCESS; +} + +graphStatus InferBasePass::UpdatePeerInputDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) { + changed = false; + return GRAPH_SUCCESS; +} + +std::vector InferBasePass::GetCurNodeSubgraphs(const NodePtr &node) { + std::vector cur_node_subgraph; + auto op_desc = node->GetOpDesc(); + auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); + if (sub_graph_names.empty()) { + return cur_node_subgraph; + } + + auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); + for (const auto &name : sub_graph_names) { + if (name.empty()) { + GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str()); + continue; + } + auto sub_graph = root_graph->GetSubgraph(name); + if (sub_graph == nullptr) { + REPORT_INNER_ERROR("E19999", "Can not find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); + GE_LOGE("[Get][Graph] can not find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); + continue; + } + cur_node_subgraph.emplace_back(sub_graph); + } + return cur_node_subgraph; +} + +graphStatus InferBasePass::UpdateTensorDescToSubgraphData(NodePtr &node, std::set &changed_nodes) { + // if infer again, update output of while into subgraph data node + auto op_desc = node->GetOpDesc(); + for (const auto &sub_graph : GetCurNodeSubgraphs(node)) { + for (const auto &node_sub : sub_graph->GetDirectNode()) { + if (node_sub->GetType() != DATA) { + continue; + } + auto name = sub_graph->GetName(); + int ref_i; + auto data_opdesc = node_sub->GetOpDesc(); + if (data_opdesc == nullptr) { + REPORT_INNER_ERROR("E19999", "Invalid data node on the sub graph %s parent node %s, no OpDesc", name.c_str(), + node->GetName().c_str()); + GE_LOGE("[Get][OpDesc] Invalid data node on the sub graph %s parent node %s, no OpDesc", name.c_str(), + node->GetName().c_str()); + return GRAPH_FAILED; + } + if (!AttrUtils::GetInt(data_opdesc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { + REPORT_INNER_ERROR("E19999", "Invalid data node on the sub graph %s parent node %s, no ref-index attribute", + name.c_str(), node->GetName().c_str()); + GE_LOGE("[Get][Int] Invalid data node on the sub graph %s parent node %s, no ref-index attribute", name.c_str(), + node->GetName().c_str()); + return GRAPH_FAILED; + } + // In multi-batch, data shape of subgraph is different, no need to refresh. + if (data_opdesc->HasAttr(ATTR_MBATCH_ORIGIN_INPUT_DIMS)) { + continue; + } + auto input_desc = op_desc->MutableInputDesc(ref_i); + if (input_desc == nullptr) { + REPORT_INNER_ERROR("E19999", + "The ref index(%d) on the data %s on the sub graph %s " + "parent node %s are incompatible, inputs num %u", + ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), + node->GetAllInDataAnchorsSize()); + GE_LOGE( + "[Call][MutableInputDesc] The ref index(%d) on the data %s on the sub graph %s " + "parent node %s are incompatible, inputs num %u", + ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), node->GetAllInDataAnchorsSize()); + return GRAPH_FAILED; + } + GELOGI("Ref index is %d, input_desc dtype is %d, node name is %s", ref_i, input_desc->GetDataType(), + node->GetName().c_str()); + + auto data_input_desc = data_opdesc->MutableInputDesc(0); + if (!SameTensorDesc(input_desc, data_input_desc)) { + changed_nodes.insert(node_sub); + // if need infer again, refresh while subgraph input with while output + if (node->GetType() == WHILE) { + input_desc = op_desc->MutableOutputDesc(ref_i); + } + } + + auto ret = data_opdesc->UpdateInputDesc(0, *input_desc); + if (ret != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E19999", "Failed to update input desc of data %s on the sub graph %s parent node %s", + node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str()); + GE_LOGE("[Update][InputDesc] of data %s on the sub graph %s parent node %s failed", node_sub->GetName().c_str(), + name.c_str(), node->GetName().c_str()); + return ret; + } + + ret = data_opdesc->UpdateOutputDesc(0, *input_desc); + if (ret != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E19999", "Failed to update output desc of data %s on the sub graph %s parent node %s", + node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str()); + GE_LOGE("[Update][OutputDesc] of data %s on the sub graph %s parent node %s failed", + node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str()); + return ret; + } + } + } + return GRAPH_SUCCESS; +} + +graphStatus InferBasePass::UpdateTensorDescToParentNode(NodePtr &node, std::set &changed_nodes) { + std::vector> ref_data_tensors(node->GetAllInDataAnchorsSize()); + std::vector> ref_out_tensors(node->GetAllOutDataAnchorsSize()); + + for (const auto &sub_graph : GetCurNodeSubgraphs(node)) { + auto name = sub_graph->GetName(); + NodePtr netoutput = nullptr; + auto ret = FindSubgraphDataAndNetoutput(sub_graph, netoutput, node, ref_data_tensors); + if (ret != GRAPH_SUCCESS) { + return ret; + } + if (netoutput == nullptr) { + REPORT_INNER_ERROR("E19999", "No NetOutput node on sub graph %s, parent node %s", name.c_str(), + node->GetName().c_str()); + GE_LOGE("[Check][Param] No NetOutput node on sub graph %s, parent node %s", name.c_str(), + node->GetName().c_str()); + return GRAPH_FAILED; + } + auto netoutput_opdesc = netoutput->GetOpDesc(); + if (netoutput_opdesc == nullptr) { + REPORT_INNER_ERROR("E19999", "Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it", + name.c_str(), node->GetName().c_str()); + GE_LOGE("[Get][OpDesc] Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it", name.c_str(), + node->GetName().c_str()); + return GRAPH_FAILED; + } + for (auto &edge_anchor : netoutput->GetAllInDataAnchors()) { + auto edge_desc = netoutput_opdesc->MutableInputDesc(edge_anchor->GetIdx()); + if (edge_desc == nullptr) { + REPORT_INNER_ERROR("E19999", + "Invalid NetOutput node on sub graph %s, parent node %s, " + "can not find input tensor %d", + name.c_str(), node->GetName().c_str(), edge_anchor->GetIdx()); + GE_LOGE("[Get][Tensor] Invalid NetOutput node on sub graph %s, parent node %s, can not find input tensor %d", + name.c_str(), node->GetName().c_str(), edge_anchor->GetIdx()); + return GRAPH_FAILED; + } + GELOGI("Netoutput in anchor index is %d, input tensor dim is %zu", edge_anchor->GetIdx(), + edge_desc->GetShape().GetDimNum()); + int ref_i; + if (!AttrUtils::GetInt(edge_desc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { + // if there is no ref index on the TensorDesc, it means the output data will be ignored outer. + continue; + } + GELOGI("Parent node index of edge desc is %d", ref_i); + if (ref_i < 0 || static_cast(ref_i) >= node->GetAllOutDataAnchorsSize()) { + return GRAPH_FAILED; + } + ref_out_tensors[ref_i].emplace_back(*edge_desc); + } + } + + if (node->GetType() == WHILE) { + return UpdateParentNodeForWhile(node, ref_data_tensors, ref_out_tensors, changed_nodes); + } + return UpdateParentNodeForBranch(node, ref_out_tensors, changed_nodes); +} + +graphStatus InferBasePass::UpdateParentNodeForWhile(NodePtr &node, + std::vector> &ref_data_tensors, + std::vector> &ref_out_tensors, + std::set &changed_nodes) { + GELOGD("Enter update parent node shape for class while op process"); + if (ref_data_tensors.size() != ref_out_tensors.size()) { + REPORT_INNER_ERROR("E19999", "op:%s(%s) input number[%zu] and output number[%zu] is not same!", + node->GetName().c_str(), node->GetType().c_str(), ref_data_tensors.size(), + ref_out_tensors.size()); + GELOGE(GRAPH_FAILED, "[Check][Param] while op [%s] input number[%zu] and output number[%zu] is not same!", + node->GetName().c_str(), ref_data_tensors.size(), ref_out_tensors.size()); + return GRAPH_FAILED; + } + + // check input and output + for (size_t i = 0; i < ref_out_tensors.size(); i++) { + if (ref_out_tensors[i].size() != 1) { + REPORT_INNER_ERROR("E19999", "while op, every output should only find one output tensor in all graph!"); + GELOGE(GRAPH_FAILED, "[Check][Param] while op, every output should only find one output tensor in all graph!"); + return GRAPH_FAILED; + } + auto ref_out_tensor = ref_out_tensors[i].at(0); + for (auto &tensor : ref_data_tensors[i]) { + // ref_i's data and output tensor shape should be same + if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { + REPORT_INNER_ERROR("E19999", "node[%s] does not support diff dtype or format among all ref output", + node->GetName().c_str()); + GELOGE(GRAPH_FAILED, "[Check][Param] node[%s] does not support diff dtype or format output.", + node->GetName().c_str()); + return GRAPH_FAILED; + } + auto data_shape = tensor.MutableShape(); + auto out_shape = ref_out_tensor.MutableShape(); + if (data_shape.GetDims() != out_shape.GetDims()) { + GELOGI("After infer, While %s %zu output shape [%s] is not match with input shape [%s].Need infer again.", + node->GetName().c_str(), i, out_shape.ToString().c_str(), data_shape.ToString().c_str()); + if (data_shape.GetDimNum() != out_shape.GetDimNum()) { + ref_out_tensor.SetUnknownDimNumShape(); + } else { + size_t data_dim_num = data_shape.GetDimNum(); + std::vector> data_shape_range = {data_dim_num, std::make_pair(1, UNKNOWN_DIM)}; + for (size_t j = 0; j < data_dim_num; ++j) { + if (data_shape.GetDim(j) != out_shape.GetDim(j)) { + data_shape.SetDim(j, UNKNOWN_DIM); + } + if (data_shape.GetDim(j) != UNKNOWN_DIM) { + data_shape_range[j] = std::make_pair(data_shape.GetDim(j), data_shape.GetDim(j)); + } + } + ref_out_tensor.SetShape(data_shape); + ref_out_tensor.SetShapeRange(data_shape_range); + } + } + } + + auto output_desc = node->GetOpDesc()->MutableOutputDesc(i); + (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor); + bool output_changed = SameTensorDesc(ComGraphMakeShared(ref_out_tensor), output_desc); + if (output_changed) { + changed_nodes.insert(node); + } + } + return GRAPH_SUCCESS; +} + +graphStatus InferBasePass::UpdateOutputForMultiBatch(NodePtr &node, + std::vector> &ref_out_tensors, + std::set &changed_nodes) { + // check sub_graph shape. Get max for update. + for (size_t i = 0; i < ref_out_tensors.size(); ++i) { + if (ref_out_tensors[i].empty()) { + continue; + } + + int64_t max_size = 0; + size_t max_shape_index = 0; + auto &ref_out_tensor = ref_out_tensors[i].at(0); + for (size_t j = 0; j < ref_out_tensors[i].size(); ++j) { + auto &tensor = ref_out_tensors[i].at(j); + if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { + REPORT_INNER_ERROR("E19999", "node[%s] does not support diff dtype among all ref output", + node->GetName().c_str()); + GELOGE(GRAPH_FAILED, "[Check][Param] node[%s] does not support diff dtype among all ref output", + node->GetName().c_str()); + return GRAPH_FAILED; + } + + auto shape = tensor.MutableShape(); + int64_t size = 1; + for (auto dim : shape.GetDims()) { + if (dim != 0 && INT64_MAX / dim < size) { + REPORT_INNER_ERROR("E19999", "The shape:%s size overflow, node:%s", shape.ToString().c_str(), + node->GetName().c_str()); + GELOGE(PARAM_INVALID, "[Check][Overflow] The shape size overflow"); + return PARAM_INVALID; + } + size *= dim; + } + + if (size > max_size) { + max_size = size; + max_shape_index = j; + } + } + + auto output_desc = node->GetOpDesc()->MutableOutputDesc(i); + (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensors[i].at(max_shape_index)); + bool output_changed = + SameTensorDesc(ComGraphMakeShared(ref_out_tensors[i].at(max_shape_index)), output_desc); + if (output_changed) { + changed_nodes.insert(node); + } + } + + return GRAPH_SUCCESS; +} + +graphStatus InferBasePass::UpdateParentNodeForBranch(NodePtr &node, + std::vector> &ref_out_tensors, + std::set &changed_nodes) { + GELOGD("Enter update parent node shape for class branch op process"); + if (node->GetOpDesc()->HasAttr(ATTR_NAME_BATCH_NUM)) { + return UpdateOutputForMultiBatch(node, ref_out_tensors, changed_nodes); + } + + // check sub_graph shape.If not same ,do unknown shape process + for (size_t i = 0; i < ref_out_tensors.size(); i++) { + if (ref_out_tensors[i].empty()) { + continue; + } + auto ref_out_tensor = ref_out_tensors[i].at(0); + ge::GeShape &ref_out_tensor_shape = ref_out_tensor.MutableShape(); + for (auto &tensor : ref_out_tensors[i]) { + if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { + REPORT_INNER_ERROR("E19999", "node[%s] does not support diff dtype among all ref output, shape:%s", + node->GetName().c_str(), ref_out_tensor_shape.ToString().c_str()); + GELOGE(GRAPH_FAILED, "[Check][Param] node[%s] does not support diff dtype output", node->GetName().c_str()); + return GRAPH_FAILED; + } + auto shape = tensor.MutableShape(); + if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) { + GELOGD("node is %s, i : %zu, shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), i, + shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); + ref_out_tensor_shape = GeShape(UNKNOWN_RANK); + break; + } + for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) { + if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) { + continue; + } + GELOGD("node is %s, i : %zu, j: %zu ,shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), + i, j, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); + (void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM); + } + } + + auto output_desc = node->GetOpDesc()->MutableOutputDesc(i); + (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor); + bool output_changed = + SameTensorDesc(ComGraphMakeShared(ref_out_tensor), output_desc); + if (output_changed) { + changed_nodes.insert(node); + } + } + return GRAPH_SUCCESS; +} + +void InferBasePass::PrintInOutTensorShape(const NodePtr &node, const std::string &phase) { + if (!IsLogEnable(GE, DLOG_DEBUG)) { + return; + } + if (node == nullptr) { + REPORT_INNER_ERROR("E19999", "param node is nullprt, check invalid"); + GELOGE(GRAPH_FAILED, "[Check][Param] node is null"); + return; + } + ge::OpDescPtr op_desc = node->GetOpDesc(); + GE_IF_BOOL_EXEC(op_desc == nullptr, REPORT_INNER_ERROR("E19999", "node has no opdesc, check invalid"); + GELOGE(GRAPH_FAILED, "[Get][OpDesc] op_desc is null."); return ); + std::stringstream ss; + ss << "{"; + int32_t in_idx = 0; + int32_t out_idx = 0; + for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { + if (input_desc == nullptr) { + in_idx++; + continue; + } + if (in_idx > 0) { + ss << " "; + } + ss << "input_" << in_idx << " " + << "tensor: ["; + ss << "(shape:[" << input_desc->MutableShape().ToString() << "]),"; + ss << "(format:" << TypeUtils::FormatToSerialString(input_desc->GetFormat()) << "),"; + ss << "(dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) << "),"; + ss << "(origin_shape:" << input_desc->GetOriginShape().ToString() << "),"; + ss << "(origin_format:" << TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()) << "),"; + ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()) << "),"; + string range_str; + SerialShapeRange(input_desc, range_str); + ss << "(shape_range:" << range_str << "),"; + std::vector> value_range; + (void)input_desc->GetValueRange(value_range); + string value_range_str = formats::RangeToString(value_range); + ss << "(value_range:" << value_range_str << ")]"; + in_idx++; + } + for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) { + if (output_desc == nullptr) { + out_idx++; + continue; + } + ss << " "; + ss << "output_" << out_idx << " " + << "tensor: ["; + ss << "(shape:[" << output_desc->MutableShape().ToString() << "]),"; + ss << "(format:" << TypeUtils::FormatToSerialString(output_desc->GetFormat()) << "),"; + ss << "(dtype:" << TypeUtils::DataTypeToSerialString(output_desc->GetDataType()) << "),"; + ss << "(origin_shape:" << output_desc->GetOriginShape().ToString() << "),"; + ss << "(origin_format:" << TypeUtils::FormatToSerialString(output_desc->GetOriginFormat()) << "),"; + ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(output_desc->GetOriginDataType()) << "),"; + string range_str; + SerialShapeRange(output_desc, range_str); + ss << "(shape_range:" << range_str << "),"; + std::vector> value_range; + (void)output_desc->GetValueRange(value_range); + string value_range_str = formats::RangeToString(value_range); + ss << "(value_range:" << value_range_str << ")]"; + out_idx++; + } + ss << "}"; + GELOGD("Shape dump [%s], Node name: [%s]. %s", phase.c_str(), node->GetName().c_str(), ss.str().c_str()); +} +} // namespace ge diff --git a/ge/graph/passes/infer_base_pass.h b/ge/graph/passes/infer_base_pass.h new file mode 100644 index 00000000..efa3c8a2 --- /dev/null +++ b/ge/graph/passes/infer_base_pass.h @@ -0,0 +1,50 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GE_GRAPH_PASSES_INFER_BASE_PASS_H_ +#define GE_GRAPH_PASSES_INFER_BASE_PASS_H_ + +#include "graph/passes/base_pass.h" + +namespace ge { +class InferBasePass : public BaseNodePass { + public: + Status Run(NodePtr &node) override; + graphStatus InferAndUpdate(NodePtr &node, bool before_subgraph, std::set &changed_nodes); + void PrintInOutTensorShape(const NodePtr &node, const std::string &phase); + + protected: + virtual bool NeedInfer(const NodePtr &node); + virtual graphStatus Infer(NodePtr &node) = 0; + virtual bool SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDescPtr &dst) = 0; + virtual graphStatus UpdatePeerInputDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) = 0; + + private: + void AddChangedNodesImmediateRepass(const std::set &changed_nodes); + bool ContainsSubgraph(const NodePtr &node); + std::vector GetCurNodeSubgraphs(const NodePtr &node); + graphStatus UpdateTensorDescToSubgraphData(NodePtr &node, std::set &changed_nodes); + graphStatus UpdateTensorDescToParentNode(NodePtr &node, std::set &changed_nodes); + graphStatus UpdateParentNodeForWhile(NodePtr &node, std::vector> &ref_data_tensors, + std::vector> &ref_out_tensors, + std::set &changed_nodes); + graphStatus UpdateParentNodeForBranch(NodePtr &node, std::vector> &ref_out_tensors, + std::set &changed_nodes); + graphStatus UpdateOutputForMultiBatch(NodePtr &node, std::vector> &ref_out_tensors, + std::set &changed_nodes); + graphStatus UpdateTensorDescToPeerInputs(NodePtr &node, std::set &changed_nodes); +}; +} // namespace ge +#endif // GE_GRAPH_PASSES_INFER_BASE_PASS_H_ diff --git a/ge/graph/passes/infer_value_range_pass.cc b/ge/graph/passes/infer_value_range_pass.cc new file mode 100644 index 00000000..377b7d34 --- /dev/null +++ b/ge/graph/passes/infer_value_range_pass.cc @@ -0,0 +1,383 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/infer_value_range_pass.h" +#include "common/util/error_manager/error_manager.h" +#include "framework/common/debug/ge_log.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/operator_factory_impl.h" +#include "graph/passes/constant_folding_pass.h" +#include "graph/utils/type_utils.h" +#include "common/ge/ge_util.h" + +using std::unique_ptr; +namespace ge { +namespace { +#define GET_DATA_BY_DTYPE(DTYPE, TYPE) \ + case (DTYPE): \ + ConstructValueRange(lower_tensor, higher_tensor, output_tensor_value_range); \ + break; + +Status RunCpuKernelForValueRange(NodePtr &node, const vector &inputs, + std::vector &outputs) { + // should use RunOpKernelWithCheck, RunOpKernel for ut test + auto ret = ConstantFoldingPass::RunOpKernel(node, inputs, outputs); + if (ret != SUCCESS) { + auto op_kernel = folding_pass::GetKernelByType(node); + if (op_kernel == nullptr) { + GELOGW("Calculate value range failed, no op kernel for node %s type %s", node->GetName().c_str(), + node->GetType().c_str()); + return NOT_CHANGED; + } + + ret = op_kernel->Compute(node->GetOpDesc(), inputs, outputs); + if (ret != SUCCESS) { + GELOGW("Calculate for node %s failed in constant folding", node->GetName().c_str()); + return NOT_CHANGED; + } + } + GELOGI("Node %s type %s, run cpu kernel success.", node->GetName().c_str(), node->GetType().c_str()); + return SUCCESS; +} +} // namespace + +graphStatus InferValueRangePass::Infer(NodePtr &node) { + PrintInOutTensorShape(node, "before_infer_value_range"); + auto infer_value_range_param = OperatorFactoryImpl::GetInferValueRangePara(node->GetType()); + + // Use registered func to calculate value range + if (!infer_value_range_param.use_cpu_kernel) { + if (infer_value_range_param.infer_value_func == nullptr) { + GELOGW("The registered func of node %s to infer value range is nullptr.", node->GetName().c_str()); + return GRAPH_NOT_CHANGED; + } + Operator op = OpDescUtils::CreateOperatorFromNode(node); + auto ret = node->GetOpDesc()->CallInferValueRangeFunc(op); + if (ret != GRAPH_SUCCESS) { + GELOGW("Node %s call infer value range func failed, ret: %u.", node->GetName().c_str(), ret); + return GRAPH_NOT_CHANGED; + } + return GRAPH_SUCCESS; + } + + // Use CPU kernel func to calculate value range + auto ret = ConstructInputAndInferValueRange(node); + if (ret != GRAPH_SUCCESS) { + GELOGW("Use CPU kernel to calculate value range failed. node: %s, ret: %u", node->GetName().c_str(), ret); + return GRAPH_NOT_CHANGED; + } + return GRAPH_SUCCESS; +} + +bool InferValueRangePass::NeedInfer(const NodePtr &node) { + auto infer_value_range_param = OperatorFactoryImpl::GetInferValueRangePara(node->GetType()); + if (!infer_value_range_param.is_initialized) { + GELOGD("Node %s does not register func to infer value range, skip infer_value_range_pass.", + node->GetName().c_str()); + return false; + } + + if (infer_value_range_param.when_call == INPUT_IS_DYNAMIC) { + // Only do infer for node that all inputs are dynamic, such as shape + if (InputIsDynamic(node)) { + return true; + } + GELOGD("Node %s register func to infer value range and when_call is INPUT_IS_DYNAMIC, but check input failed.", + node->GetName().c_str()); + } else if (infer_value_range_param.when_call == INPUT_HAS_VALUE_RANGE) { + // Only do infer for node that all inputs have value_range or node type of inputs is constant/const + if (InputIsConstOrHasValueRange(node)) { + return true; + } + GELOGD("Node %s register func to infer value range and when_call is INPUT_HAS_VALUE_RANGE, but check input failed.", + node->GetName().c_str()); + } + GELOGD("Node %s does not need to infer value range, skip infer_value_range_pass.", node->GetName().c_str()); + return false; +} + +bool InferValueRangePass::InputIsDynamic(const NodePtr &node) { + bool input_is_dynamic = false; + auto cur_op_desc = node->GetOpDesc(); + for (const auto &input_desc : cur_op_desc->GetAllInputsDescPtr()) { + auto dims = input_desc->GetShape().GetDims(); + for (auto dim : dims) { + if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) { + input_is_dynamic = true; + break; + } + } + } + return input_is_dynamic; +} + +bool InferValueRangePass::InputIsConstOrHasValueRange(const NodePtr &node) { + bool input_is_const_or_has_value_range = true; + auto cur_op_desc = node->GetOpDesc(); + auto in_data_anchors = node->GetAllInDataAnchors(); + for (size_t i = 0; i < in_data_anchors.size(); ++i) { + auto peer_out_anchor = in_data_anchors.at(i)->GetPeerOutAnchor(); + if (peer_out_anchor == nullptr) { + continue; + } + auto peer_node = peer_out_anchor->GetOwnerNode(); + if (peer_node == nullptr || peer_node->GetOpDesc() == nullptr) { + continue; + } + if ((peer_node->GetType() == CONSTANT) || (peer_node->GetType() == CONSTANTOP)) { + continue; + } + + const auto &input_desc = cur_op_desc->GetInputDesc(i); + std::vector> value_range; + (void)input_desc.GetValueRange(value_range); + if (value_range.empty()) { + GELOGD("Node %s input %zu does not have value range, skip infer_value_range_pass for current node.", + node->GetName().c_str(), i); + input_is_const_or_has_value_range = false; + break; + } + } + return input_is_const_or_has_value_range; +} + + +bool InferValueRangePass::SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDescPtr &dst) { + bool same_desc = true; + std::vector> src_value_range; + std::vector> dst_value_range; + (void)src->GetValueRange(src_value_range); + (void)dst->GetValueRange(dst_value_range); + if (src_value_range != dst_value_range) { + same_desc = false; + } + return same_desc; +} + +graphStatus InferValueRangePass::UpdatePeerInputDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) { + changed = false; + std::vector> src_value_range; + std::vector> dst_value_range; + (void)src->GetValueRange(src_value_range); + (void)dst->GetValueRange(dst_value_range); + if (src_value_range != dst_value_range) { + changed = true; + } + + dst->SetValueRange(src_value_range); + return GRAPH_SUCCESS; +} + +template +graphStatus InferValueRangePass::ConstructData(const GeTensorDesc &tensor_desc, bool use_floor_value, + GeTensorPtr &output_ptr) { + std::vector> value_range; + (void)tensor_desc.GetValueRange(value_range); + if (static_cast(value_range.size()) != tensor_desc.GetShape().GetShapeSize()) { + REPORT_INNER_ERROR("E19999", "Value range of input %s is invalid.", tensor_desc.GetName().c_str()); + GELOGE(GRAPH_PARAM_INVALID, "Value range of input %s is invalid.", tensor_desc.GetName().c_str()); + return GRAPH_PARAM_INVALID; + } + + size_t value_range_data_num = value_range.size(); + unique_ptr buf(new (std::nothrow) T[value_range_data_num]()); + if (buf == nullptr) { + REPORT_INNER_ERROR("E19999", "New buf failed"); + GELOGE(MEMALLOC_FAILED, "new buf failed"); + return GRAPH_FAILED; + } + for (size_t j = 0; j < value_range_data_num; ++j) { + auto value_range_j = use_floor_value ? value_range[j].first : value_range[j].second; + buf[j] = static_cast(value_range_j); + } + + if (output_ptr->SetData(reinterpret_cast(buf.get()), value_range_data_num * sizeof(T)) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "set data failed"); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +graphStatus InferValueRangePass::ConstructDataByType(const GeTensorDesc &tensor_desc, bool use_floor_value, + GeTensorPtr &output_ptr) { + graphStatus ret = GRAPH_SUCCESS; + auto data_type = tensor_desc.GetDataType(); + output_ptr->MutableTensorDesc().SetDataType(data_type); + switch (data_type) { + case DT_FLOAT: + ret = ConstructData(tensor_desc, use_floor_value, output_ptr); + break; + case DT_DOUBLE: + ret = ConstructData(tensor_desc, use_floor_value, output_ptr); + break; + case DT_UINT8: + ret = ConstructData(tensor_desc, use_floor_value, output_ptr); + break; + case DT_INT8: + ret = ConstructData(tensor_desc, use_floor_value, output_ptr); + break; + case DT_UINT16: + ret = ConstructData(tensor_desc, use_floor_value, output_ptr); + break; + case DT_INT16: + ret = ConstructData(tensor_desc, use_floor_value, output_ptr); + break; + case DT_INT32: + ret = ConstructData(tensor_desc, use_floor_value, output_ptr); + break; + case DT_INT64: + ret = ConstructData(tensor_desc, use_floor_value, output_ptr); + break; + default: + GELOGW("Data type:%s is not supported.", TypeUtils::DataTypeToSerialString(data_type).c_str()); + ret = GRAPH_FAILED; + } + return ret; +} + +vector InferValueRangePass::ConstructInputTensors(const NodePtr &node, bool use_floor_value) { + vector input_tensors; + auto cur_op_desc = node->GetOpDesc(); + auto in_data_anchors = node->GetAllInDataAnchors(); + for (size_t i = 0; i < in_data_anchors.size(); ++i) { + auto peer_out_anchor = in_data_anchors.at(i)->GetPeerOutAnchor(); + if (peer_out_anchor == nullptr) { + continue; + } + auto peer_node = peer_out_anchor->GetOwnerNode(); + if (peer_node == nullptr) { + continue; + } + + // construct input tensor by constant node + if ((peer_node->GetType() == CONSTANT) || (peer_node->GetType() == CONSTANTOP)) { + vector const_weight = OpDescUtils::MutableWeights(peer_node); + if (const_weight.empty()) { + REPORT_INNER_ERROR("E19999", "MutableWeights failed, weight is empty, node: %s(%s)", + peer_node->GetName().c_str(), peer_node->GetType().c_str()); + GELOGE(INTERNAL_ERROR, "MutableWeights failed, weight is empty, node: %s(%s)", peer_node->GetName().c_str(), + peer_node->GetType().c_str()); + return vector(); + } + // const/constant op has only one weight + if (const_weight.at(0) == nullptr) { + REPORT_INNER_ERROR("E19999", "MutableWeights failed, weight of constant is null, node: %s(%s)", + peer_node->GetName().c_str(), peer_node->GetType().c_str()); + GELOGE(INTERNAL_ERROR, "MutableWeights failed, weight of constant is null, node name: %s(%s)", + peer_node->GetName().c_str(), peer_node->GetType().c_str()); + return vector(); + } + input_tensors.push_back(const_weight.at(0)); + continue; + } + + // construct input tensor by boundary of value range + const auto &input_tensor_desc = cur_op_desc->GetInputDesc(i); + GeTensorPtr tmp_tensor_ptr = MakeShared(input_tensor_desc); + if (tmp_tensor_ptr == nullptr) { + REPORT_INNER_ERROR("E19999", "Make shared failed"); + GELOGE(MEMALLOC_FAILED, "Make shared failed"); + return vector(); + } + + auto ret = ConstructDataByType(input_tensor_desc, use_floor_value, tmp_tensor_ptr); + if (ret != GRAPH_SUCCESS) { + REPORT_INNER_ERROR("E19999", "Input %s construct input tensor by boundary of value range failed.", + input_tensor_desc.GetName().c_str()); + GELOGE(GRAPH_PARAM_INVALID, "Input %s construct input tensor by boundary of value range failed.", + input_tensor_desc.GetName().c_str()); + return vector(); + } + input_tensors.push_back(tmp_tensor_ptr); + } + + return input_tensors; +} + +graphStatus InferValueRangePass::ConstructInputAndInferValueRange(NodePtr &node) { + auto inputs = ConstructInputTensors(node, true); + if (inputs.empty()) { + return GRAPH_PARAM_INVALID; + } + vector outputs_lower; + auto ret = RunCpuKernelForValueRange(node, inputs, outputs_lower); + if (ret != SUCCESS) { + REPORT_INNER_ERROR("E19999", "Calculate for node %s(%s) failed", node->GetName().c_str(), node->GetType().c_str()); + GELOGE(GRAPH_FAILED, "Calculate for node %s failed in constant folding", node->GetName().c_str()); + return GRAPH_FAILED; + } + + inputs = ConstructInputTensors(node, false); + if (inputs.empty()) { + return GRAPH_PARAM_INVALID; + } + vector outputs_higher; + ret = RunCpuKernelForValueRange(node, inputs, outputs_higher); + if (ret != SUCCESS) { + REPORT_INNER_ERROR("E19999", "Calculate for node %s(%s) failed", node->GetName().c_str(), node->GetType().c_str()); + GELOGE(GRAPH_FAILED, "Calculate for node %s failed in constant folding", node->GetName().c_str()); + return GRAPH_FAILED; + } + + // construct value range from output tensor + OpDescPtr node_desc = node->GetOpDesc(); + std::vector> output_tensor_value_range; + size_t node_output_desc_size = node_desc->GetOutputsSize(); + for (size_t i = 0; i < node_output_desc_size; ++i) { + output_tensor_value_range.clear(); + auto lower_tensor = outputs_lower[i]; + auto lower_tensor_shape_size = lower_tensor->GetTensorDesc().GetShape().GetShapeSize(); + auto higher_tensor = outputs_higher[i]; + auto higher_tensor_shape_size = higher_tensor->GetTensorDesc().GetShape().GetShapeSize(); + auto output_tensor_desc = node_desc->MutableOutputDesc(i); + auto output_tensor_shape_size = output_tensor_desc->GetShape().GetShapeSize(); + if (output_tensor_shape_size != lower_tensor_shape_size || output_tensor_shape_size != higher_tensor_shape_size) { + GELOGE(GRAPH_PARAM_INVALID, "Value range of output %s is invalid.", output_tensor_desc->GetName().c_str()); + } + + auto data_type = output_tensor_desc->GetDataType(); + switch (data_type) { + GET_DATA_BY_DTYPE(DT_INT8, int8_t) + GET_DATA_BY_DTYPE(DT_INT16, int16_t) + GET_DATA_BY_DTYPE(DT_INT32, int32_t) + GET_DATA_BY_DTYPE(DT_INT64, int64_t) + GET_DATA_BY_DTYPE(DT_UINT8, uint8_t) + GET_DATA_BY_DTYPE(DT_UINT16, uint16_t) + GET_DATA_BY_DTYPE(DT_UINT32, uint32_t) + GET_DATA_BY_DTYPE(DT_UINT64, uint64_t) + GET_DATA_BY_DTYPE(DT_FLOAT, float) + GET_DATA_BY_DTYPE(DT_DOUBLE, double) + default: + GELOGW("Data type:%s is not supported.", TypeUtils::DataTypeToSerialString(data_type).c_str()); + return GRAPH_FAILED; + } + output_tensor_desc->SetValueRange(output_tensor_value_range); + } + return GRAPH_SUCCESS; +} + +template +void InferValueRangePass::ConstructValueRange(const GeTensorPtr &left_tensor, const GeTensorPtr &right_tensor, + std::vector> &value_range) { + auto x = reinterpret_cast(left_tensor->GetData().GetData()); + auto y = reinterpret_cast(right_tensor->GetData().GetData()); + for (auto j = 0; j < left_tensor->GetTensorDesc().GetShape().GetShapeSize(); ++j) { + auto left = static_cast(*(x + j)); + auto right = static_cast(*(y + j)); + value_range.emplace_back(std::make_pair(left, right)); + } +} +} // namespace ge diff --git a/ge/graph/passes/infer_value_range_pass.h b/ge/graph/passes/infer_value_range_pass.h new file mode 100644 index 00000000..8f9be18b --- /dev/null +++ b/ge/graph/passes/infer_value_range_pass.h @@ -0,0 +1,45 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_INFER_VALUE_RANGE_PASS_H_ +#define GE_GRAPH_PASSES_INFER_VALUE_RANGE_PASS_H_ + +#include "graph/passes/infer_base_pass.h" + +namespace ge { +class InferValueRangePass : public InferBasePass { + public: + graphStatus Infer(NodePtr &node) override; + + protected: + bool NeedInfer(const NodePtr &node) override; + bool SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDescPtr &dst) override; + graphStatus UpdatePeerInputDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override; + + private: + bool InputIsDynamic(const NodePtr &node); + bool InputIsConstOrHasValueRange(const NodePtr &node); + template + graphStatus ConstructData(const GeTensorDesc &tensor_desc, bool use_floor_value, GeTensorPtr &output_ptr); + graphStatus ConstructDataByType(const GeTensorDesc &tensor_desc, bool use_floor_value, GeTensorPtr &output_ptr); + vector ConstructInputTensors(const NodePtr &node, bool use_floor_value); + template + void ConstructValueRange(const GeTensorPtr &left_tensor, const GeTensorPtr &right_tensor, + std::vector> &value_range); + graphStatus ConstructInputAndInferValueRange(NodePtr &node); +}; +} // namespace ge +#endif // GE_GRAPH_PASSES_INFER_VALUE_RANGE_PASS_H_ diff --git a/ge/graph/preprocess/graph_preprocess.cc b/ge/graph/preprocess/graph_preprocess.cc index 2eae6023..b8062fb6 100644 --- a/ge/graph/preprocess/graph_preprocess.cc +++ b/ge/graph/preprocess/graph_preprocess.cc @@ -54,6 +54,7 @@ #include "graph/passes/hccl_group_pass.h" #include "graph/passes/identity_pass.h" #include "graph/passes/infershape_pass.h" +#include "graph/passes/infer_value_range_pass.h" #include "graph/passes/merge_pass.h" #include "graph/passes/net_output_pass.h" #include "graph/passes/no_use_reshape_remove_pass.h" @@ -1989,6 +1990,8 @@ Status GraphPrepare::InferShapeForPreprocess() { names_to_passes.emplace_back("MergePass", &merge_pass); InferShapePass infer_shape_pass; names_to_passes.emplace_back("InferShapePass", &infer_shape_pass); + InferValueRangePass infer_value_pass; + names_to_passes.emplace_back("InferValuePass", &infer_value_pass); ReplaceWithEmptyConstPass replace_with_empty_const_pass; names_to_passes.emplace_back("ReplaceWithEmptyConstPass", &replace_with_empty_const_pass); DimensionComputePass dimension_compute_pass; diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 5bff0f98..e81c2a76 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -219,7 +219,9 @@ set(COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/passes/shape_operate_op_remove_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/assert_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/dropout_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/infer_base_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/infershape_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/infer_value_range_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/unused_const_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/permute_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/ctrl_edge_transfer_pass.cc" @@ -532,7 +534,9 @@ set(GRAPH_PASS_COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/passes/transpose_transdata_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/hccl_memcpy_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/no_use_reshape_remove_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/infer_base_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/infershape_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/infer_value_range_pass.cc" "${GE_CODE_DIR}/ge/ge_local_engine/engine/host_cpu_engine.cc" "${GE_CODE_DIR}/ge/analyzer/analyzer.cc" "${GE_CODE_DIR}/ge/graph/passes/net_output_pass.cc" @@ -708,6 +712,7 @@ set(PASS_TEST_FILES "graph/passes/net_output_pass_unittest.cc" "graph/passes/no_use_reshape_remove_pass_unittest.cc" "graph/passes/infershape_pass_unittest.cc" + "graph/passes/infer_value_range_pass_unittest.cc" "graph/passes/mark_force_unknown_for_cond_pass_unittest.cc" "graph/passes/multi_batch_clone_pass_unittest.cc" "graph/passes/replace_with_empty_const_pass_unittest.cc" diff --git a/tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc b/tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc new file mode 100644 index 00000000..b16b8971 --- /dev/null +++ b/tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc @@ -0,0 +1,816 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define protected public +#define private public +#include "graph/passes/infer_value_range_pass.h" +#include "graph/utils/tensor_utils.h" +#include "graph/utils/graph_utils.h" +#include "graph_builder_utils.h" + +#include "inc/external/graph/operator_reg.h" +#include "inc/external/graph/operator.h" +#include "inc/external/graph/operator_factory.h" +#include "inc/graph/operator_factory_impl.h" +#include "inc/kernel.h" +#include "inc/kernel_factory.h" + +using namespace std; +using namespace testing; +namespace ge { +class UtestGraphInferValueRangePass : public testing::Test { + protected: + void SetUp() {} + void TearDown() {} +}; + +/* + * data1 const1 + * \ / + * case1 + * | + * relu10 + * | + * netoutput + */ +ut::GraphBuilder ParentGraphBuilder() { + ut::GraphBuilder builder = ut::GraphBuilder("g1"); + auto data1 = builder.AddNode("data1", "Data", 0, 1); + std::vector const_shape = {1}; + auto const1 = builder.AddNode("const1", "Const", 0, 1, FORMAT_NCHW, DT_INT32, const_shape); + auto case1 = builder.AddNode("case1", CASE, 2, 1); + auto relu1 = builder.AddNode("relu10", "Relu", 1, 1); + auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0); + + int32_t weight[1] = {1}; + GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_INT32); + GeTensorPtr tensor = std::make_shared(weight_desc, (uint8_t *)weight, sizeof(weight)); + OpDescUtils::SetWeights(const1, {tensor}); + auto case_in0_shape = GeShape({1,1,-1,224}); + case1->GetOpDesc()->MutableInputDesc(0)->SetShape(case_in0_shape); + std::vector> in_range = {make_pair(1, 1), make_pair(1, 1), + make_pair(1, -1), make_pair(1, 224)}; + case1->GetOpDesc()->MutableInputDesc(0)->SetValueRange(in_range); + auto case_in1_shape = GeShape({1,1}); + case1->GetOpDesc()->MutableInputDesc(1)->SetShape(case_in1_shape); + + builder.AddDataEdge(data1, 0, case1, 0); + builder.AddDataEdge(const1, 0, case1, 1); + builder.AddDataEdge(case1, 0, relu1, 0); + builder.AddDataEdge(relu1, 0, netoutput, 0); + return builder; +} + +/* + * data1 data2 + * \ / + * switch + * / \ + * relu1 relu2 + * \ / + * merge + * | + * netoutput + */ +ut::GraphBuilder SwitchSubgraphBuilder(string graph_name, uint32_t num) { + ut::GraphBuilder builder = ut::GraphBuilder(graph_name); + + std::vector shape1 = {2,2}; + string data1_name = "data1_" + std::to_string(num); + auto data1 = builder.AddNode(data1_name, "Data", 1, 1, FORMAT_NCHW, DT_INT32, shape1); + auto data1_desc = data1->GetOpDesc(); + EXPECT_NE(data1_desc, nullptr); + AttrUtils::SetInt(data1_desc, "_parent_node_index", 0); + + std::vector shape2 = {3,3}; + string data2_name = "data2_" + std::to_string(num); + auto data2 = builder.AddNode(data2_name, "Data", 1, 1, FORMAT_NCHW, DT_INT32, shape2); + auto data2_desc = data2->GetOpDesc(); + EXPECT_NE(data2_desc, nullptr); + AttrUtils::SetInt(data2_desc, "_parent_node_index", 1); + + string switch_name = "switch_" + std::to_string(num); + auto switch1 = builder.AddNode(switch_name, "Switch", 2, 2); + + string relu1_name = "relu1_" + std::to_string(num); + auto relu1 = builder.AddNode(relu1_name, "Relu", 1, 1); + + string relu2_name = "relu2_" + std::to_string(num); + auto relu2 = builder.AddNode(relu2_name, "Relu", 1, 1); + + string merge_name = "merge_" + std::to_string(num); + auto merge = builder.AddNode(merge_name, "Merge", 2, 1); + + std::vector shape7 = {8,8}; + string output_name = "output_" + std::to_string(num); + auto netoutput = builder.AddNode(output_name, NETOUTPUT, 1, 0, FORMAT_NCHW, DT_INT32, shape7); + auto input0_desc = netoutput->GetOpDesc()->MutableInputDesc(0); + EXPECT_NE(input0_desc, nullptr); + AttrUtils::SetInt(input0_desc, "_parent_node_index", 0); + std::vector> range = {make_pair(1, -1), make_pair(1, -1)}; + input0_desc->SetValueRange(range); + + builder.AddDataEdge(data1, 0, switch1, 0); + builder.AddDataEdge(data2, 0, switch1, 1); + builder.AddDataEdge(switch1, 0, relu1, 0); + builder.AddDataEdge(switch1, 1, relu2, 0); + builder.AddDataEdge(relu1, 0, merge, 0); + builder.AddDataEdge(relu2, 0, merge, 1); + builder.AddDataEdge(merge, 0, netoutput, 0); + + return builder; +} + +void AddCaseSubgraph(ComputeGraphPtr &parent_graph, uint32_t branch_num) { + auto case_node = parent_graph->FindNode("case1"); + EXPECT_NE(case_node, nullptr); + + for (uint32_t i = 0; i < branch_num; ++i) { + string name = "Branch_Graph_" + std::to_string(i); + + auto builder_subgraph = SwitchSubgraphBuilder(name, i); + auto switch_subgraph = builder_subgraph.GetGraph(); + + case_node->GetOpDesc()->AddSubgraphName(switch_subgraph->GetName()); + case_node->GetOpDesc()->SetSubgraphInstanceName(i, switch_subgraph->GetName()); + + switch_subgraph->SetParentNode(case_node); + switch_subgraph->SetParentGraph(parent_graph); + EXPECT_EQ(parent_graph->AddSubgraph(switch_subgraph->GetName(), switch_subgraph), GRAPH_SUCCESS); + } +} + +TEST_F(UtestGraphInferValueRangePass, infer_pass_not_register) { + auto graph = std::make_shared("test_graph"); + GeTensorDesc ge_tensor_desc(GeShape({1, 1, 4, 192}), ge::FORMAT_NCHW, DT_FLOAT16); + auto addn_op_desc = std::make_shared("AddN", "AddN"); + addn_op_desc->AddInputDesc(ge_tensor_desc); + addn_op_desc->AddOutputDesc(ge_tensor_desc); + auto addn_op_node = graph->AddNode(addn_op_desc); + + InferValueRangePass infer_pass; + EXPECT_EQ(infer_pass.Run(addn_op_node), SUCCESS); +} + +auto ShapeValueInfer = [&](Operator &op) { + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + auto output_tensor_desc = op_desc->MutableOutputDesc(0); + std::vector> in_shape_range; + op_desc->MutableInputDesc(0)->GetShapeRange(in_shape_range); + if (!in_shape_range.empty()) { + output_tensor_desc->SetValueRange(in_shape_range); + } + return SUCCESS; +}; +REG_OP(Shape) + .OP_END_FACTORY_REG(Shape) +IMPL_INFER_VALUE_RANGE_FUNC(Shape, ShapeValueRangeFunc){ + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + auto output_tensor_desc = op_desc->MutableOutputDesc(0); + std::vector> in_shape_range; + op_desc->MutableInputDesc(0)->GetShapeRange(in_shape_range); + if (!in_shape_range.empty()) { + output_tensor_desc->SetValueRange(in_shape_range); + } + return GRAPH_SUCCESS; +} + +TEST_F(UtestGraphInferValueRangePass, infer_pass_when_call_1_not_infer) { + INFER_VALUE_RANGE_CUSTOM_FUNC_REG(Shape, INPUT_IS_DYNAMIC, ShapeValueRangeFunc); + auto graph = std::make_shared("test_graph"); + GeTensorDesc ge_tensor_desc(GeShape({1, 1, 4, 192}), ge::FORMAT_NCHW, DT_INT32); + std::vector> shape_range = {make_pair(1, 1), make_pair(1, 1), + make_pair(4, 4), make_pair(192, 192)}; + ge_tensor_desc.SetShapeRange(shape_range); + GeTensorDesc output_tensor_desc(GeShape({4}), ge::FORMAT_NCHW, DT_INT32); + auto op_desc = std::make_shared("Shape", "Shape"); + op_desc->AddInputDesc(ge_tensor_desc); + op_desc->AddOutputDesc(output_tensor_desc); + auto op_node = graph->AddNode(op_desc); + + InferValueRangePass infer_pass; + EXPECT_EQ(infer_pass.Run(op_node), SUCCESS); + + auto output_0_desc = op_node->GetOpDesc()->GetOutputDesc(0); + std::vector> value_range; + output_0_desc.GetValueRange(value_range); + EXPECT_EQ(value_range.empty(), true); +} + +TEST_F(UtestGraphInferValueRangePass, infer_pass_when_call_1_infer) { + // sqrt -> shape -> Output + INFER_VALUE_RANGE_CUSTOM_FUNC_REG(Shape, INPUT_IS_DYNAMIC, ShapeValueRangeFunc); + auto graph = std::make_shared("test_graph"); + GeTensorDesc sqrt_tensor_desc(GeShape({-1, -1, 4, 192}), ge::FORMAT_NCHW, DT_INT32); + std::vector> shape_range = {make_pair(1, 100), make_pair(1, 240), + make_pair(4, 4), make_pair(192, 192)}; + sqrt_tensor_desc.SetShapeRange(shape_range); + auto sqrt_op_desc = std::make_shared("Sqrt", "Sqrt"); + sqrt_op_desc->AddInputDesc(sqrt_tensor_desc); + sqrt_op_desc->AddOutputDesc(sqrt_tensor_desc); + auto sqrt_node = graph->AddNode(sqrt_op_desc); + + GeTensorDesc shape_output_desc(GeShape({4}), ge::FORMAT_NCHW, DT_INT32); + auto shape_op_desc = std::make_shared("Shape", "Shape"); + shape_op_desc->AddInputDesc(sqrt_tensor_desc); + shape_op_desc->AddOutputDesc(shape_output_desc); + auto shape_node = graph->AddNode(shape_op_desc); + + GeTensorDesc Output_in_tensor_desc(GeShape({4}), ge::FORMAT_NCHW, ge::DT_INT32); + auto Output_op_desc = std::make_shared("Output", "Output"); + Output_op_desc->AddInputDesc(Output_in_tensor_desc); + auto Output_node = graph->AddNode(Output_op_desc); + + ge::GraphUtils::AddEdge(sqrt_node->GetOutDataAnchor(0), shape_node->GetInDataAnchor(0)); + ge::GraphUtils::AddEdge(shape_node->GetOutDataAnchor(0), Output_node->GetInDataAnchor(0)); + EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); + + + InferValueRangePass infer_pass; + auto ret = infer_pass.Run(shape_node); + EXPECT_EQ(ret, SUCCESS); + + auto output_0_desc = shape_node->GetOpDesc()->GetOutputDesc(0); + std::vector> value_range; + output_0_desc.GetValueRange(value_range); + EXPECT_EQ(value_range.size(), 4); + std::vector target_value_range = {1, 100, 1, 240, 4, 4, 192, 192}; + std::vector output_value_range; + for (auto pair : value_range) { + output_value_range.push_back(pair.first); + output_value_range.push_back(pair.second); + } + EXPECT_EQ(target_value_range, output_value_range); + + auto in_0_desc = Output_node->GetOpDesc()->GetInputDesc(0); + value_range.clear(); + in_0_desc.GetValueRange(value_range); + EXPECT_EQ(value_range.size(), 4); + output_value_range.clear(); + for (auto pair : value_range) { + output_value_range.push_back(pair.first); + output_value_range.push_back(pair.second); + } + EXPECT_EQ(target_value_range, output_value_range); + +} + +class AddKernel : public Kernel { + public: + Status Compute(const ge::OpDescPtr op_desc_ptr, const std::vector &input, + std::vector &v_output) override { + vector data_vec; + auto data_num = input[0]->GetTensorDesc().GetShape().GetShapeSize(); + auto x1_data = reinterpret_cast(input[0]->GetData().data()); + auto x2_data = reinterpret_cast(input[1]->GetData().data()); + for (size_t i = 0; i < data_num; i++) { + auto x_index = *(x1_data + i); + auto y_index = *(x2_data + i); + data_vec.push_back(x_index + y_index); + } + GeTensorPtr const_tensor = std::make_shared(input[0]->GetTensorDesc(), (uint8_t *)data_vec.data(), + data_num * sizeof(int64_t)); + v_output.emplace_back(const_tensor); + return SUCCESS; + } +}; +REGISTER_KERNEL(ADD, AddKernel); + +TEST_F(UtestGraphInferValueRangePass, infer_pass_when_call_2_infer) { + // shape --- add --- sqrt + // constant / + INFER_VALUE_RANGE_DEFAULT_REG(Add); + INFER_VALUE_RANGE_DEFAULT_REG("Sqrt"); + auto graph = std::make_shared("test_graph"); + + vector dims_vec = {4}; + vector data_vec = {1, 1, 1, 1}; + GeTensorDesc const_tensor_desc(ge::GeShape(dims_vec), ge::FORMAT_NCHW, ge::DT_INT64); + GeTensorPtr const_tensor = + std::make_shared(const_tensor_desc, (uint8_t *)data_vec.data(), data_vec.size() * sizeof(int64_t)); + + auto const_op_desc = std::make_shared("Constant", "Constant"); + const_op_desc->AddOutputDesc(const_tensor_desc); + EXPECT_EQ(OpDescUtils::SetWeights(const_op_desc, const_tensor), GRAPH_SUCCESS); + auto const_node = graph->AddNode(const_op_desc); + + GeTensorDesc shape_tensor_desc(GeShape({4}), ge::FORMAT_NCHW, ge::DT_INT64); + std::vector> value_range = {make_pair(1, 100), make_pair(1, 240), + make_pair(4, 4), make_pair(192, 192)}; + shape_tensor_desc.SetValueRange(value_range); + auto shape_op_desc = std::make_shared("Shape", "Shape"); + shape_op_desc->AddOutputDesc(shape_tensor_desc); + auto shape_node = graph->AddNode(shape_op_desc); + + GeTensorDesc add_tensor_desc(GeShape({4}), ge::FORMAT_NCHW, ge::DT_INT64); + auto add_op_desc = std::make_shared("Add", "Add"); + add_op_desc->AddInputDesc(shape_tensor_desc); + add_op_desc->AddInputDesc(const_tensor_desc); + add_op_desc->AddOutputDesc(add_tensor_desc); + auto add_node = graph->AddNode(add_op_desc); + + auto sqrt_op_desc = std::make_shared("Sqrt", "Sqrt"); + sqrt_op_desc->AddInputDesc(GeTensorDesc()); + auto sqrt_node = graph->AddNode(sqrt_op_desc); + + ge::GraphUtils::AddEdge(shape_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(0)); + ge::GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(1)); + ge::GraphUtils::AddEdge(add_node->GetOutDataAnchor(0), sqrt_node->GetInDataAnchor(1)); + + InferValueRangePass infer_pass; + EXPECT_EQ(infer_pass.Run(sqrt_node), SUCCESS); + EXPECT_EQ(infer_pass.Run(add_node), SUCCESS); + + auto output_0_desc = add_node->GetOpDesc()->GetOutputDesc(0); + std::vector> out_value_range; + output_0_desc.GetValueRange(out_value_range); + EXPECT_EQ(out_value_range.size(), 4); + + std::vector target_value_range = {2, 101, 2, 241, 5, 5, 193, 193}; + std::vector output_value_range; + for (auto pair : out_value_range) { + output_value_range.push_back(pair.first); + output_value_range.push_back(pair.second); + } + EXPECT_EQ(target_value_range, output_value_range); +} + +REG_OP(Case) + .OP_END_FACTORY_REG(Case) +IMPL_INFER_VALUE_RANGE_FUNC(Case, ValueRangeFunc){ + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + auto output_tensor_desc = op_desc->MutableOutputDesc(0); + std::vector> in_shape_range = {make_pair(1, 2), make_pair(1, 3), + make_pair(1, 4), make_pair(1, 5)};; + output_tensor_desc->SetValueRange(in_shape_range); + return GRAPH_SUCCESS; +} +TEST_F(UtestGraphInferValueRangePass, infer_with_case_subgraph) { + auto builder = ParentGraphBuilder(); + auto parent_graph = builder.GetGraph(); + AddCaseSubgraph(parent_graph, 2); + auto subgraphs = parent_graph->GetAllSubgraphs(); + EXPECT_EQ(subgraphs.size(), 2); + + // check before subgraph + auto case_node = parent_graph->FindNode("case1"); + EXPECT_NE(case_node, nullptr); + INFER_VALUE_RANGE_CUSTOM_FUNC_REG(Case, INPUT_HAS_VALUE_RANGE, ValueRangeFunc); + InferValueRangePass infer_pass; + EXPECT_EQ(infer_pass.Run(case_node), SUCCESS); + + auto case_out_0_desc = case_node->GetOpDesc()->MutableOutputDesc(0); + std::vector> out_value_range; + case_out_0_desc->GetValueRange(out_value_range); + EXPECT_EQ(out_value_range.size(), 4); + std::vector target_value_range = {1,2,1,3,1,4,1,5}; + std::vector output_value_range_list; + for (auto pair : out_value_range) { + output_value_range_list.push_back(pair.first); + output_value_range_list.push_back(pair.second); + } + EXPECT_EQ(target_value_range, output_value_range_list); + + std::vector target_dims_0 = {1, 1, -1, 224}; + std::vector target_dims_1 = {1,1}; + auto data_node = subgraphs[0]->FindNode("data1_0"); + auto data_output_0_desc = data_node->GetOpDesc()->GetOutputDesc(0); + EXPECT_EQ(target_dims_0, data_output_0_desc.GetShape().GetDims()); + data_node = subgraphs[0]->FindNode("data2_0"); + auto data2_output_0_desc = data_node->GetOpDesc()->GetOutputDesc(0); + EXPECT_EQ(target_dims_1, data2_output_0_desc.GetShape().GetDims()); + + // check after subgraph + infer_pass.options_[kOptimizeAfterSubGraph] = "yes"; + EXPECT_EQ(infer_pass.Run(case_node), SUCCESS); + + std::vector out_target_dims = {1, -1, 1, -1}; + auto case_out = case_node->GetOpDesc()->GetOutputDescPtr(0); + out_value_range.clear(); + case_out->GetValueRange(out_value_range); + EXPECT_EQ(out_value_range.size(), 2); + + output_value_range_list.clear(); + for (auto pair : out_value_range) { + output_value_range_list.push_back(pair.first); + output_value_range_list.push_back(pair.second); + } + EXPECT_EQ(out_target_dims, output_value_range_list); +} + +/* + * data1 const1 + * \ / + * while + * / \ + * relu1 netoutput + */ +ut::GraphBuilder ParentWhileGraphBuilder() { + ut::GraphBuilder builder = ut::GraphBuilder("g1"); + auto data1 = builder.AddNode("data1", "Data", 0, 1); + std::vector const_shape = {1}; + auto const1 = builder.AddNode("const1", "Const", 0, 1, FORMAT_NCHW, DT_FLOAT, const_shape); + auto while1 = builder.AddNode("while1", WHILE, 2, 2); + auto relu1 = builder.AddNode("relu1", "Relu", 1, 1); + auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0); + + int32_t weight[1] = {1}; + GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_FLOAT); + GeTensorPtr tensor = std::make_shared(weight_desc, (uint8_t *)weight, sizeof(weight)); + OpDescUtils::SetWeights(const1, {tensor}); + std::vector> in_range = {make_pair(1, 1), make_pair(1, 1), + make_pair(1, 224), make_pair(1, 224)}; + while1->GetOpDesc()->MutableInputDesc(0)->SetValueRange(in_range); + + builder.AddDataEdge(data1, 0, while1, 0); + builder.AddDataEdge(const1, 0, while1, 1); + builder.AddDataEdge(while1, 0, relu1, 0); + builder.AddDataEdge(while1, 1, netoutput, 0); + return builder; +} + +/* + * data1 data2 + * \ / + * switch + * | | + * \ / + * netoutput + */ +ut::GraphBuilder WhileSubgraphBuilder(string graph_name, uint32_t num) { + ut::GraphBuilder builder = ut::GraphBuilder(graph_name); + + std::vector shape1 = {2,2}; + string data1_name = "data1_" + std::to_string(num); + auto data1 = builder.AddNode(data1_name, "Data", 1, 1, FORMAT_NCHW, DT_FLOAT, shape1); + auto data1_desc = data1->GetOpDesc(); + EXPECT_NE(data1_desc, nullptr); + AttrUtils::SetInt(data1_desc, "_parent_node_index", 0); + + std::vector shape2 = {3,3}; + string data2_name = "data2_" + std::to_string(num); + auto data2 = builder.AddNode(data2_name, "Data", 1, 1, FORMAT_NCHW, DT_FLOAT, shape2); + auto data2_desc = data2->GetOpDesc(); + EXPECT_NE(data2_desc, nullptr); + AttrUtils::SetInt(data2_desc, "_parent_node_index", 1); + + string switch_name = "switch_" + std::to_string(num); + auto switch1 = builder.AddNode(switch_name, "Switch", 2, 2); + + std::vector shape7 = {8,8,8,8}; + string output_name = "output_" + std::to_string(num); + auto netoutput = builder.AddNode(output_name, NETOUTPUT, 2, 0, FORMAT_NCHW, DT_FLOAT, shape7); + auto input0_desc = netoutput->GetOpDesc()->MutableInputDesc(0); + EXPECT_NE(input0_desc, nullptr); + AttrUtils::SetInt(input0_desc, "_parent_node_index", 0); + std::vector> range0 = {make_pair(1, -1), make_pair(1, -1)}; + input0_desc->SetValueRange(range0); + auto input1_desc = netoutput->GetOpDesc()->MutableInputDesc(1); + EXPECT_NE(input1_desc, nullptr); + AttrUtils::SetInt(input1_desc, "_parent_node_index", 1); + std::vector> range1 = {make_pair(8, 8), make_pair(8, 8),make_pair(8, 8),make_pair(8, 8)}; + input1_desc->SetValueRange(range1); + + builder.AddDataEdge(data1, 0, switch1, 0); + builder.AddDataEdge(data2, 0, switch1, 1); + builder.AddDataEdge(switch1, 0, netoutput, 0); + builder.AddDataEdge(switch1, 1, netoutput, 1); + return builder; +} + +void AddWhileSubgraph(ComputeGraphPtr &parent_graph, uint32_t branch_num) { + auto while_node = parent_graph->FindNode("while1"); + EXPECT_NE(while_node, nullptr); + + for (uint32_t i = 0; i < branch_num; ++i) { + string name = "Branch_Graph_" + std::to_string(i); + + auto builder_subgraph = WhileSubgraphBuilder(name, i); + auto switch_subgraph = builder_subgraph.GetGraph(); + + while_node->GetOpDesc()->AddSubgraphName(switch_subgraph->GetName()); + while_node->GetOpDesc()->SetSubgraphInstanceName(i, switch_subgraph->GetName()); + switch_subgraph->SetParentNode(while_node); + switch_subgraph->SetParentGraph(parent_graph); + EXPECT_EQ(parent_graph->AddSubgraph(switch_subgraph->GetName(), switch_subgraph), GRAPH_SUCCESS); + } +} + +REG_OP(While) + .OP_END_FACTORY_REG(While) +IMPL_INFER_VALUE_RANGE_FUNC(While, WhileValueRangeFunc){ + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + std::vector> in_range = {make_pair(1, 2), make_pair(1, 3), + make_pair(1, 4), make_pair(1, 5)};; + for (auto i =0; iGetOutputsSize();++i){ + auto output_tensor_desc = op_desc->MutableOutputDesc(i); + output_tensor_desc->SetValueRange(in_range); + } + return GRAPH_SUCCESS; +} +INFER_VALUE_RANGE_CUSTOM_FUNC_REG(While, INPUT_HAS_VALUE_RANGE, WhileValueRangeFunc); +TEST_F(UtestGraphInferValueRangePass, infer_with_while_subgraph) { + auto builder = ParentWhileGraphBuilder(); + auto parent_graph = builder.GetGraph(); + AddWhileSubgraph(parent_graph, 1); + auto subgraphs = parent_graph->GetAllSubgraphs(); + EXPECT_EQ(subgraphs.size(), 1); + + // check before subgraph + auto while_node = parent_graph->FindNode("while1"); + EXPECT_NE(while_node, nullptr); + InferValueRangePass infer_pass; + EXPECT_EQ(infer_pass.Run(while_node), SUCCESS); + + auto while_out_0_desc = while_node->GetOpDesc()->MutableOutputDesc(0); + std::vector> out_value_range; + while_out_0_desc->GetValueRange(out_value_range); + EXPECT_EQ(out_value_range.size(), 4); + std::vector target_value_range = {1,2,1,3,1,4,1,5}; + std::vector output_value_range_list; + for (auto pair : out_value_range) { + output_value_range_list.push_back(pair.first); + output_value_range_list.push_back(pair.second); + } + EXPECT_EQ(target_value_range, output_value_range_list); + + std::vector target_dims_0 = {1, 1, 224, 224}; + auto data_node = subgraphs[0]->FindNode("data1_0"); + auto data_input_0_desc = data_node->GetOpDesc()->GetInputDesc(0); + EXPECT_EQ(target_dims_0, data_input_0_desc.GetShape().GetDims()); + + // check after subgraph + infer_pass.options_[kOptimizeAfterSubGraph] = "yes"; + EXPECT_EQ(infer_pass.Run(while_node), SUCCESS); + + std::vector out_target_dims = {1, -1, 1, -1}; + auto while_out0 = while_node->GetOpDesc()->GetOutputDescPtr(0); + out_value_range.clear(); + while_out0->GetValueRange(out_value_range); + EXPECT_EQ(out_value_range.size(), 2); + output_value_range_list.clear(); + for (auto pair : out_value_range) { + output_value_range_list.push_back(pair.first); + output_value_range_list.push_back(pair.second); + } + EXPECT_EQ(output_value_range_list, out_target_dims); + + std::vector out_target_dims_1 = {8,8, 8,8, 8,8, 8,8}; + auto while_out1 = while_node->GetOpDesc()->GetOutputDescPtr(1); + out_value_range.clear(); + while_out1->GetValueRange(out_value_range); + EXPECT_EQ(out_value_range.size(), 4); + output_value_range_list.clear(); + for (auto pair : out_value_range) { + output_value_range_list.push_back(pair.first); + output_value_range_list.push_back(pair.second); + } + EXPECT_EQ(output_value_range_list, out_target_dims_1); +} + +TEST_F(UtestGraphInferValueRangePass, infer_with_while_subgraph_failed) { + auto builder = ParentWhileGraphBuilder(); + auto parent_graph = builder.GetGraph(); + AddWhileSubgraph(parent_graph, 2); + auto subgraphs = parent_graph->GetAllSubgraphs(); + EXPECT_EQ(subgraphs.size(), 2); + + auto case_node = parent_graph->FindNode("while1"); + EXPECT_NE(case_node, nullptr); + InferValueRangePass infer_pass; + infer_pass.options_[kOptimizeAfterSubGraph] = "yes"; + EXPECT_EQ(infer_pass.Run(case_node), GRAPH_FAILED); +} + + + +bool IsEmptyTensor(const GeShape &ge_shape) { + bool is_empty = false; + for (const auto &dim : ge_shape.GetDims()) { + if (dim == 0) { + is_empty = true; + break; + } + } + return is_empty; +} +bool IsEmptyTensor(GeTensorDescPtr tensor_desc) { + return IsEmptyTensor(tensor_desc->MutableShape()); +} +graphStatus ReshapeRangeInferAllDims(const std::vector> &x_shape_range, + const GeShape &x_shape, + const std::vector> &shape_value_range, + std::vector> &y_shape_range, GeShape &y_shape) { + // input_shape is not constant, can not get accurate shape value. + if (x_shape.GetDims() == UNKNOWN_RANK) { + return GRAPH_SUCCESS; + } + + // step 1, calculate input_x range max + int64_t range_max = 1; + auto x_shape_size = x_shape.GetShapeSize(); + if (x_shape_size > 0) { + // known dim, x_shape_size == range_max + range_max = x_shape_size; + } else { + // unknown dim + if (x_shape_range.empty()) { + return GRAPH_SUCCESS; + } + for (const auto &pair : x_shape_range) { + if (pair.second < 0) { + range_max = -1; + break; + } + range_max *= pair.second; + } + } + + // step 2, init y shape range + auto y_rank = y_shape.GetDims().size(); + auto shape_range_max = (range_max > INT32_MAX) ? INT32_MAX : range_max; + for (auto i = 0; i < y_rank; ++i) { + y_shape_range.emplace_back(std::pair(1, shape_range_max)); + } + if (shape_value_range.empty()) { + // no value range, can not calculate accurate shape range. + return GRAPH_SUCCESS; + } + + // step 2, repair value range and check zero in value range + bool has_zero_in_value_range = false; + std::vector> value_range = shape_value_range; + for (auto &pair : value_range) { + if (pair.first < 0) { + pair.first = 1; + } + if (pair.second < 0) { + pair.second = -1; + } + if (pair.first == 0) { + has_zero_in_value_range = true; + } + } + + // step 3, deal with empty tensor. if no value range cannot infer empty tensor. + if (IsEmptyTensor(x_shape)) { + if (range_max != 0) { + return GRAPH_FAILED; + } + if (!has_zero_in_value_range) { + return GRAPH_FAILED; + } + std::vector y_dims = y_shape.GetDims(); + for (auto i = 0; i < y_rank; ++i) { + if (value_range[i].first == value_range[i].second) { + y_dims[i] = value_range[i].first; + } + } + y_shape_range = value_range; + y_shape = GeShape(y_dims); + return GRAPH_SUCCESS; + } + + // step 4, calculate accurate dims and shape_range + std::vector y_dims = y_shape.GetDims(); + for (auto i = 0; i < y_rank; ++i) { + if (value_range[i].first == value_range[i].second) { + y_dims[i] = value_range[i].first; + y_shape_range[i] = std::pair(y_dims[i], y_dims[i]); + } else { + if (range_max == -1) { + // while range_max = -1, range_max && value_range[i].second is always value_range[i].second; + y_shape_range[i] = std::pair(value_range[i].first, value_range[i].second); + continue; + } + int64_t other_dims_range_lower_boundary = 1; + for (auto j = 0; j < y_rank; ++j) { + if (i == j) { + continue; + } + other_dims_range_lower_boundary *= value_range[j].first; + + } + int64_t cur_dim_range_max = static_cast( + (static_cast(range_max) + other_dims_range_lower_boundary - 1) / other_dims_range_lower_boundary); + if (value_range[i].second == -1) { + cur_dim_range_max = (cur_dim_range_max < INT32_MAX) ? cur_dim_range_max : INT32_MAX; + y_shape_range[i] = std::pair(value_range[i].first, cur_dim_range_max); + continue; + } + cur_dim_range_max = (cur_dim_range_max < value_range[i].second) ? cur_dim_range_max : value_range[i].second; + cur_dim_range_max = (cur_dim_range_max < INT32_MAX) ? cur_dim_range_max : INT32_MAX; + y_shape_range[i] = std::pair(value_range[i].first, cur_dim_range_max); + } + } + y_shape = GeShape(y_dims); + return GRAPH_SUCCESS; +} + +TEST_F(UtestGraphInferValueRangePass, reshape_infer_func_test_1) { + auto rank = 4; + std::vector> x_shape_range = {make_pair(1, 100), make_pair(1, 400)}; + GeShape x_shape = GeShape(std::vector(2, UNKNOWN_DIM)); + std::vector> shape_value_range = {make_pair(100, -1), make_pair(-1, -10), + make_pair(1, 20), make_pair(10, 10)}; + std::vector> y_shape_range; + GeShape y_shape = GeShape(std::vector(rank, UNKNOWN_DIM)); + auto ret = ReshapeRangeInferAllDims(x_shape_range, x_shape, shape_value_range, y_shape_range, y_shape); + EXPECT_EQ(ret, GRAPH_SUCCESS); + + std::vector target_y_shape_dims = {-1, -1, -1, 10}; + EXPECT_EQ(y_shape.GetDims(), target_y_shape_dims); + + std::vector target_y_shape_range = {100, 4000, 1, 40, 1, 20, 10, 10}; + std::vector output_shape_range; + for (auto pair : y_shape_range) { + output_shape_range.push_back(pair.first); + output_shape_range.push_back(pair.second); + } + EXPECT_EQ(output_shape_range, target_y_shape_range); +} + +TEST_F(UtestGraphInferValueRangePass, reshape_infer_func_test_2) { + auto rank = 4; + std::vector> x_shape_range = {make_pair(1, 100), make_pair(1, 400), make_pair(-1, -1)}; + GeShape x_shape = GeShape(std::vector(3, UNKNOWN_DIM)); + std::vector> shape_value_range = {make_pair(100, -1), make_pair(1, -10), + make_pair(1, 20), make_pair(10, 10)}; + std::vector> y_shape_range; + GeShape y_shape = GeShape(std::vector(rank, UNKNOWN_DIM)); + auto ret = ReshapeRangeInferAllDims(x_shape_range, x_shape, shape_value_range, y_shape_range, y_shape); + EXPECT_EQ(ret, GRAPH_SUCCESS); + + std::vector target_y_shape_dims = {-1, -1, -1, 10}; + EXPECT_EQ(y_shape.GetDims(), target_y_shape_dims); + + std::vector target_y_shape_range = {100, -1, 1, -1, 1, 20, 10, 10}; + std::vector output_shape_range; + for (auto pair : y_shape_range) { + output_shape_range.push_back(pair.first); + output_shape_range.push_back(pair.second); + } + EXPECT_EQ(output_shape_range, target_y_shape_range); +} + +TEST_F(UtestGraphInferValueRangePass, reshape_infer_func_test_3) { + auto rank = 4; + std::vector> x_shape_range = {}; + GeShape x_shape = GeShape(std::vector(3, 100)); + std::vector> shape_value_range = {make_pair(100, -1), make_pair(1, -10), + make_pair(1, 20), make_pair(10, 10)}; + std::vector> y_shape_range; + GeShape y_shape = GeShape(std::vector(rank, UNKNOWN_DIM)); + auto ret = ReshapeRangeInferAllDims(x_shape_range, x_shape, shape_value_range, y_shape_range, y_shape); + EXPECT_EQ(ret, GRAPH_SUCCESS); + + std::vector target_y_shape_dims = {-1, -1, -1, 10}; + EXPECT_EQ(y_shape.GetDims(), target_y_shape_dims); + + std::vector target_y_shape_range = {100, 100000, 1, 1000, 1, 20, 10, 10}; + std::vector output_shape_range; + for (auto pair : y_shape_range) { + output_shape_range.push_back(pair.first); + output_shape_range.push_back(pair.second); + } + EXPECT_EQ(output_shape_range, target_y_shape_range); +} +TEST_F(UtestGraphInferValueRangePass, reshape_infer_func_test_4) { + auto rank = 4; + std::vector> x_shape_range = {make_pair(1, 100), make_pair(0, 0)}; + GeShape x_shape = GeShape({-1, 0}); + std::vector> shape_value_range = {make_pair(0, 0), make_pair(-1, -10), + make_pair(10, 20), make_pair(100, 100)}; + std::vector> y_shape_range; + GeShape y_shape = GeShape(std::vector(rank, UNKNOWN_DIM)); + auto ret = ReshapeRangeInferAllDims(x_shape_range, x_shape, shape_value_range, y_shape_range, y_shape); + EXPECT_EQ(ret, GRAPH_SUCCESS); + + std::vector target_y_shape_dims = {0, -1, -1, 100}; + EXPECT_EQ(y_shape.GetDims(), target_y_shape_dims); + + std::vector target_y_shape_range = {0, 0, 1, -1, 10, 20, 100, 100}; + std::vector output_shape_range; + for (auto pair : y_shape_range) { + output_shape_range.push_back(pair.first); + output_shape_range.push_back(pair.second); + } + EXPECT_EQ(output_shape_range, target_y_shape_range); +} + +} // namespace ge