Browse Source

Pre Merge pull request !1789 from 王强/master

pull/1789/MERGE
王强 Gitee 4 years ago
parent
commit
8451d8fb9f
14 changed files with 1926 additions and 21 deletions
  1. +4
    -0
      ge/CMakeLists.txt
  2. +13
    -0
      ge/common/formats/utils/formats_trans_utils.cc
  3. +2
    -0
      ge/common/formats/utils/formats_trans_utils.h
  4. +15
    -11
      ge/graph/passes/constant_folding_pass.cc
  5. +5
    -0
      ge/graph/passes/constant_folding_pass.h
  6. +0
    -8
      ge/graph/passes/folding_pass.cc
  7. +0
    -2
      ge/graph/passes/folding_pass.h
  8. +585
    -0
      ge/graph/passes/infer_base_pass.cc
  9. +50
    -0
      ge/graph/passes/infer_base_pass.h
  10. +383
    -0
      ge/graph/passes/infer_value_range_pass.cc
  11. +45
    -0
      ge/graph/passes/infer_value_range_pass.h
  12. +3
    -0
      ge/graph/preprocess/graph_preprocess.cc
  13. +5
    -0
      tests/ut/ge/CMakeLists.txt
  14. +816
    -0
      tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc

+ 4
- 0
ge/CMakeLists.txt View File

@@ -297,7 +297,9 @@ set(TRAIN_SRC_LIST
"graph/passes/hccl_continuous_memcpy_pass.cc"
"graph/passes/identity_pass.cc"
"graph/passes/ref_identity_delete_op_pass.cc"
"graph/passes/infer_base_pass.cc"
"graph/passes/infershape_pass.cc"
"graph/passes/infer_value_range_pass.cc"
"graph/passes/iterator_op_pass.cc"
"graph/passes/link_gen_mask_nodes_pass.cc"
"graph/passes/merge_pass.cc"
@@ -546,7 +548,9 @@ set(INFER_SRC_LIST
"graph/passes/shape_operate_op_remove_pass.cc"
"graph/passes/assert_pass.cc"
"graph/passes/dropout_pass.cc"
"graph/passes/infer_base_pass.cc"
"graph/passes/infershape_pass.cc"
"graph/passes/infer_value_range_pass.cc"
"graph/passes/unused_const_pass.cc"
"graph/passes/permute_pass.cc"
"graph/passes/ctrl_edge_transfer_pass.cc"


+ 13
- 0
ge/common/formats/utils/formats_trans_utils.cc View File

@@ -49,6 +49,19 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string ShapeToString(const s
return JoinToString(shape);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
std::string RangeToString(const std::vector<std::pair<int64_t, int64_t>> &range) {
string serial_string;
serial_string += "[";
for (const auto &pair : range) {
serial_string += "{";
serial_string += std::to_string(pair.first) + "," + std::to_string(pair.second);
serial_string += "},";
}
serial_string += "]";
return serial_string;
}

int64_t GetItemNumByShape(const std::vector<int64_t> &shape) {
int64_t num = 1;
for (auto dim : shape) {


+ 2
- 0
ge/common/formats/utils/formats_trans_utils.h View File

@@ -54,6 +54,8 @@ std::string ShapeToString(const GeShape &shape);

std::string ShapeToString(const std::vector<int64_t> &shape);

std::string RangeToString(const std::vector<std::pair<int64_t, int64_t>> &range);

int64_t GetItemNumByShape(const std::vector<int64_t> &shape);

bool CheckShapeValid(const std::vector<int64_t> &shape, const int64_t expect_dims);


+ 15
- 11
ge/graph/passes/constant_folding_pass.cc View File

@@ -20,17 +20,23 @@
#include "graph/operator_factory.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/type_utils.h"
#include "ge_local_engine/engine/host_cpu_engine.h"
#include "init/gelib.h"

namespace ge {
const int64_t kStartCallNum = 1;
const std::string kKernelLibName = "aicpu_tf_kernel";
// tf_kernel.json opsFlag config
const std::string kOpsFlagClose = "0";

Status RunOpKernelWithCheck(NodePtr &node,
const vector<ConstGeTensorPtr> &inputs,
std::vector<GeTensorPtr> &outputs) {
const map<string, pair<uint64_t, uint64_t>> &ConstantFoldingPass::GetGeConstantFoldingPerfStatistic() const {
return statistic_of_ge_constant_folding_;
}
const map<string, pair<uint64_t, uint64_t>> &ConstantFoldingPass::GetOpConstantFoldingPerfStatistic() const {
return statistic_of_op_constant_folding_;
}

Status ConstantFoldingPass::RunOpKernelWithCheck(NodePtr &node, const vector<ConstGeTensorPtr> &inputs,
std::vector<GeTensorPtr> &outputs) {
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Check][Param] GE is not initialized or is finalized.");
@@ -47,15 +53,13 @@ Status RunOpKernelWithCheck(NodePtr &node,
if (ops_flag == kOpsFlagClose) {
return UNSUPPORTED;
}
return FoldingPass::RunOpKernel(node, inputs, outputs);
return RunOpKernel(node, inputs, outputs);
}

const map<string, pair<uint64_t, uint64_t>> &ConstantFoldingPass::GetGeConstantFoldingPerfStatistic() const {
return statistic_of_ge_constant_folding_;
}

const map<string, pair<uint64_t, uint64_t>> &ConstantFoldingPass::GetOpConstantFoldingPerfStatistic() const {
return statistic_of_op_constant_folding_;
Status ConstantFoldingPass::RunOpKernel(NodePtr &node,
const vector<ConstGeTensorPtr> &inputs,
std::vector<GeTensorPtr> &outputs) {
return HostCpuEngine::GetInstance().Run(node, inputs, outputs);
}

Status ConstantFoldingPass::Run(ge::NodePtr &node) {


+ 5
- 0
ge/graph/passes/constant_folding_pass.h View File

@@ -28,6 +28,11 @@ class ConstantFoldingPass : public FoldingPass {
Status Run(ge::NodePtr &node) override;
const std::map<std::string, std::pair<std::uint64_t, uint64_t>> &GetGeConstantFoldingPerfStatistic() const;
const std::map<std::string, std::pair<std::uint64_t, uint64_t>> &GetOpConstantFoldingPerfStatistic() const;

static Status RunOpKernel(NodePtr &node, const vector<ConstGeTensorPtr> &inputs, vector<GeTensorPtr> &outputs);
static Status RunOpKernelWithCheck(NodePtr &node, const vector<ConstGeTensorPtr> &inputs,
std::vector<GeTensorPtr> &outputs);

private:
std::map<std::string, std::pair<std::uint64_t, uint64_t>> statistic_of_op_constant_folding_;
std::map<std::string, std::pair<std::uint64_t, uint64_t>> statistic_of_ge_constant_folding_;


+ 0
- 8
ge/graph/passes/folding_pass.cc View File

@@ -28,8 +28,6 @@
#include "inc/kernel.h"
#include "inc/kernel_factory.h"
#include "graph/debug/ge_attr_define.h"
#include "ge_local_engine/engine/host_cpu_engine.h"


namespace ge {
namespace folding_pass {
@@ -123,12 +121,6 @@ NodePtr AddIdentityNodeToGraph(const std::string &name, const GeTensorDesc &tens
}
} // namespace

Status FoldingPass::RunOpKernel(NodePtr &node,
const vector<ConstGeTensorPtr> &inputs,
std::vector<GeTensorPtr> &outputs) {
return HostCpuEngine::GetInstance().Run(node, inputs, outputs);
}

Status FoldingPass::Folding(NodePtr &node, vector<GeTensorPtr> &outputs) {
GE_CHECK_NOTNULL(node);
GELOGD("begin folding node:%s", node->GetName().c_str());


+ 0
- 2
ge/graph/passes/folding_pass.h View File

@@ -34,8 +34,6 @@ bool IsNoNeedConstantFolding(const NodePtr &node);
using IndexsToAnchors = std::map<int, std::vector<InDataAnchorPtr>>;

class FoldingPass : public BaseNodePass {
public:
static Status RunOpKernel(NodePtr &node, const vector<ConstGeTensorPtr> &inputs, vector<GeTensorPtr> &outputs);
protected:
Status Folding(NodePtr &node, vector<GeTensorPtr> &outputs);
private:


+ 585
- 0
ge/graph/passes/infer_base_pass.cc View File

@@ -0,0 +1,585 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "infer_base_pass.h"
#include "common/ge/ge_util.h"
#include "common/formats/utils/formats_trans_utils.h"
#include "common/util/error_manager/error_manager.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/util.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/debug/ge_util.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"

namespace ge {
namespace {
void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) {
std::vector<std::pair<int64_t, int64_t>> shape_range;
(void)desc->GetShapeRange(shape_range);
desc_str += formats::RangeToString(shape_range);
shape_range.clear();
(void)desc->GetOriginShapeRange(shape_range);
desc_str += ",";
desc_str += formats::RangeToString(shape_range);
shape_range.clear();
}

graphStatus FindSubgraphDataAndNetoutput(const ComputeGraphPtr &sub_graph, NodePtr &netoutput, const ConstNodePtr &node,
std::vector<std::vector<GeTensorDesc>> &ref_data_tensors) {
auto sub_nodes = sub_graph->GetDirectNode();
for (size_t i = sub_nodes.size(); i > 0; --i) {
auto sub_node = sub_nodes.at(i - 1);
if (sub_node->GetType() == NETOUTPUT) {
netoutput = sub_node;
}
if (sub_node->GetType() == DATA) {
if (sub_node->GetOpDesc() == nullptr) {
return GRAPH_FAILED;
}

int ref_i;
if (!AttrUtils::GetInt(sub_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
REPORT_INNER_ERROR("E19999", "subgraph data node[%s] has no parent node!", sub_node->GetName().c_str());
GELOGE(GRAPH_FAILED, "[Get][Int] subgraph data node[%s] has no parent node!", sub_node->GetName().c_str());
return GRAPH_FAILED;
}
if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllInDataAnchorsSize()) {
REPORT_INNER_ERROR("E19999", "data node[%s]'s ref index[%d] is not in range [0, %u)!",
sub_node->GetName().c_str(), ref_i, node->GetAllInDataAnchorsSize());
GELOGE(GRAPH_FAILED, "[Check][Param] data node[%s]'s ref index[%d] is not in range [0, %u)!",
sub_node->GetName().c_str(), ref_i, node->GetAllInDataAnchorsSize());
return GRAPH_FAILED;
}
ref_data_tensors[ref_i].emplace_back(sub_node->GetOpDesc()->GetOutputDesc(0));
}
}
return GRAPH_SUCCESS;
}
} // namespace

Status InferBasePass::Run(NodePtr &node) {
GE_CHECK_NOTNULL(node);
GE_CHECK_NOTNULL(node->GetOpDesc());

bool need_infer = NeedInfer(node);
if (!need_infer) {
GELOGD("Node %s does not need to infer.", node->GetName().c_str());
return SUCCESS;
}

std::set<NodePtr> changed_nodes;
auto ret = InferAndUpdate(node, !OptionExists(kOptimizeAfterSubGraph), changed_nodes);
if (ret != GRAPH_SUCCESS) {
GELOGE(ret, "Infer and update for node %s failed! ret: %u", node->GetName().c_str(), ret);
return GRAPH_FAILED;
}

AddChangedNodesImmediateRepass(changed_nodes);
return SUCCESS;
}

bool InferBasePass::NeedInfer(const NodePtr &node) { return true; }
void InferBasePass::AddChangedNodesImmediateRepass(const std::set<NodePtr> &changed_nodes) {
for (const auto &node_ele : changed_nodes) {
AddImmediateRePassNode(node_ele);
}
}

graphStatus InferBasePass::InferAndUpdate(NodePtr &node, bool before_subgraph, std::set<NodePtr> &changed_nodes) {
graphStatus ret ;
bool contain_subgraph = ContainsSubgraph(node);
if (contain_subgraph && before_subgraph) {
ret = UpdateTensorDescToSubgraphData(node, changed_nodes);
if (ret != GRAPH_SUCCESS) {
GELOGE(ret, "Update subgraph data tensor desc for node %s failed! ret: %u", node->GetName().c_str(), ret);
return ret;
}
}
ret = Infer(node);
if (ret != GRAPH_SUCCESS) {
GELOGE(ret, "Infer failed for node %s, ret: %u", node->GetName().c_str(), ret);
return ret;
}
if (contain_subgraph && !before_subgraph) {
ret = UpdateTensorDescToParentNode(node, changed_nodes);
if (ret != GRAPH_SUCCESS) {
GELOGE(ret, "Update parent tensor desc for node %s failed! ret: %u", node->GetName().c_str(), ret);
return ret;
}
}

ret = UpdateTensorDescToPeerInputs(node, changed_nodes);
if (ret != GRAPH_SUCCESS) {
GELOGE(ret, "Node %s updates tensor desc to peer input nodes failed! ret: %u", node->GetName().c_str(), ret);
}
return ret;
}

bool InferBasePass::ContainsSubgraph(const NodePtr &node) {
auto op_desc = node->GetOpDesc();
auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
if (sub_graph_names.empty()) {
return false;
}

auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph());
if (root_graph == nullptr) {
return false;
}
for (const auto &name : sub_graph_names) {
if (name.empty()) {
continue;
}
auto sub_graph = root_graph->GetSubgraph(name);
if (sub_graph != nullptr) {
return true;
}
}
return false;
}

graphStatus InferBasePass::UpdateTensorDescToPeerInputs(NodePtr &node, std::set<NodePtr> &changed_nodes) {
PrintInOutTensorShape(node, "after_infer");
auto op_desc = node->GetOpDesc();
for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx());
for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) {
auto peer_anchor_opdesc = peer_anchor->GetOwnerNode()->GetOpDesc();
if (peer_anchor_opdesc == nullptr) {
continue;
}
auto peer_input_desc = peer_anchor_opdesc->MutableInputDesc(peer_anchor->GetIdx());
if (peer_input_desc == nullptr) {
continue;
}

bool changed = false;
auto ret = UpdatePeerInputDesc(output_tensor, peer_input_desc, changed);
if (ret != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "Update peer input desc failed, node %s.", node->GetName().c_str());
GELOGE(ret, "Update peer input desc failed, node %s.", node->GetName().c_str());
return ret;
}
if (changed) {
changed_nodes.insert(peer_anchor->GetOwnerNode());
}
}
}
return GRAPH_SUCCESS;
}

graphStatus InferBasePass::UpdatePeerInputDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) {
changed = false;
return GRAPH_SUCCESS;
}

std::vector<ComputeGraphPtr> InferBasePass::GetCurNodeSubgraphs(const NodePtr &node) {
std::vector<ComputeGraphPtr> cur_node_subgraph;
auto op_desc = node->GetOpDesc();
auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
if (sub_graph_names.empty()) {
return cur_node_subgraph;
}

auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph());
for (const auto &name : sub_graph_names) {
if (name.empty()) {
GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str());
continue;
}
auto sub_graph = root_graph->GetSubgraph(name);
if (sub_graph == nullptr) {
REPORT_INNER_ERROR("E19999", "Can not find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str());
GE_LOGE("[Get][Graph] can not find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str());
continue;
}
cur_node_subgraph.emplace_back(sub_graph);
}
return cur_node_subgraph;
}

graphStatus InferBasePass::UpdateTensorDescToSubgraphData(NodePtr &node, std::set<NodePtr> &changed_nodes) {
// if infer again, update output of while into subgraph data node
auto op_desc = node->GetOpDesc();
for (const auto &sub_graph : GetCurNodeSubgraphs(node)) {
for (const auto &node_sub : sub_graph->GetDirectNode()) {
if (node_sub->GetType() != DATA) {
continue;
}
auto name = sub_graph->GetName();
int ref_i;
auto data_opdesc = node_sub->GetOpDesc();
if (data_opdesc == nullptr) {
REPORT_INNER_ERROR("E19999", "Invalid data node on the sub graph %s parent node %s, no OpDesc", name.c_str(),
node->GetName().c_str());
GE_LOGE("[Get][OpDesc] Invalid data node on the sub graph %s parent node %s, no OpDesc", name.c_str(),
node->GetName().c_str());
return GRAPH_FAILED;
}
if (!AttrUtils::GetInt(data_opdesc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
REPORT_INNER_ERROR("E19999", "Invalid data node on the sub graph %s parent node %s, no ref-index attribute",
name.c_str(), node->GetName().c_str());
GE_LOGE("[Get][Int] Invalid data node on the sub graph %s parent node %s, no ref-index attribute", name.c_str(),
node->GetName().c_str());
return GRAPH_FAILED;
}
// In multi-batch, data shape of subgraph is different, no need to refresh.
if (data_opdesc->HasAttr(ATTR_MBATCH_ORIGIN_INPUT_DIMS)) {
continue;
}
auto input_desc = op_desc->MutableInputDesc(ref_i);
if (input_desc == nullptr) {
REPORT_INNER_ERROR("E19999",
"The ref index(%d) on the data %s on the sub graph %s "
"parent node %s are incompatible, inputs num %u",
ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(),
node->GetAllInDataAnchorsSize());
GE_LOGE(
"[Call][MutableInputDesc] The ref index(%d) on the data %s on the sub graph %s "
"parent node %s are incompatible, inputs num %u",
ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), node->GetAllInDataAnchorsSize());
return GRAPH_FAILED;
}
GELOGI("Ref index is %d, input_desc dtype is %d, node name is %s", ref_i, input_desc->GetDataType(),
node->GetName().c_str());

auto data_input_desc = data_opdesc->MutableInputDesc(0);
if (!SameTensorDesc(input_desc, data_input_desc)) {
changed_nodes.insert(node_sub);
// if need infer again, refresh while subgraph input with while output
if (node->GetType() == WHILE) {
input_desc = op_desc->MutableOutputDesc(ref_i);
}
}

auto ret = data_opdesc->UpdateInputDesc(0, *input_desc);
if (ret != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "Failed to update input desc of data %s on the sub graph %s parent node %s",
node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str());
GE_LOGE("[Update][InputDesc] of data %s on the sub graph %s parent node %s failed", node_sub->GetName().c_str(),
name.c_str(), node->GetName().c_str());
return ret;
}

ret = data_opdesc->UpdateOutputDesc(0, *input_desc);
if (ret != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "Failed to update output desc of data %s on the sub graph %s parent node %s",
node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str());
GE_LOGE("[Update][OutputDesc] of data %s on the sub graph %s parent node %s failed",
node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str());
return ret;
}
}
}
return GRAPH_SUCCESS;
}

graphStatus InferBasePass::UpdateTensorDescToParentNode(NodePtr &node, std::set<NodePtr> &changed_nodes) {
std::vector<std::vector<GeTensorDesc>> ref_data_tensors(node->GetAllInDataAnchorsSize());
std::vector<std::vector<GeTensorDesc>> ref_out_tensors(node->GetAllOutDataAnchorsSize());

for (const auto &sub_graph : GetCurNodeSubgraphs(node)) {
auto name = sub_graph->GetName();
NodePtr netoutput = nullptr;
auto ret = FindSubgraphDataAndNetoutput(sub_graph, netoutput, node, ref_data_tensors);
if (ret != GRAPH_SUCCESS) {
return ret;
}
if (netoutput == nullptr) {
REPORT_INNER_ERROR("E19999", "No NetOutput node on sub graph %s, parent node %s", name.c_str(),
node->GetName().c_str());
GE_LOGE("[Check][Param] No NetOutput node on sub graph %s, parent node %s", name.c_str(),
node->GetName().c_str());
return GRAPH_FAILED;
}
auto netoutput_opdesc = netoutput->GetOpDesc();
if (netoutput_opdesc == nullptr) {
REPORT_INNER_ERROR("E19999", "Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it",
name.c_str(), node->GetName().c_str());
GE_LOGE("[Get][OpDesc] Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it", name.c_str(),
node->GetName().c_str());
return GRAPH_FAILED;
}
for (auto &edge_anchor : netoutput->GetAllInDataAnchors()) {
auto edge_desc = netoutput_opdesc->MutableInputDesc(edge_anchor->GetIdx());
if (edge_desc == nullptr) {
REPORT_INNER_ERROR("E19999",
"Invalid NetOutput node on sub graph %s, parent node %s, "
"can not find input tensor %d",
name.c_str(), node->GetName().c_str(), edge_anchor->GetIdx());
GE_LOGE("[Get][Tensor] Invalid NetOutput node on sub graph %s, parent node %s, can not find input tensor %d",
name.c_str(), node->GetName().c_str(), edge_anchor->GetIdx());
return GRAPH_FAILED;
}
GELOGI("Netoutput in anchor index is %d, input tensor dim is %zu", edge_anchor->GetIdx(),
edge_desc->GetShape().GetDimNum());
int ref_i;
if (!AttrUtils::GetInt(edge_desc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
// if there is no ref index on the TensorDesc, it means the output data will be ignored outer.
continue;
}
GELOGI("Parent node index of edge desc is %d", ref_i);
if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllOutDataAnchorsSize()) {
return GRAPH_FAILED;
}
ref_out_tensors[ref_i].emplace_back(*edge_desc);
}
}

if (node->GetType() == WHILE) {
return UpdateParentNodeForWhile(node, ref_data_tensors, ref_out_tensors, changed_nodes);
}
return UpdateParentNodeForBranch(node, ref_out_tensors, changed_nodes);
}

graphStatus InferBasePass::UpdateParentNodeForWhile(NodePtr &node,
std::vector<std::vector<GeTensorDesc>> &ref_data_tensors,
std::vector<std::vector<GeTensorDesc>> &ref_out_tensors,
std::set<NodePtr> &changed_nodes) {
GELOGD("Enter update parent node shape for class while op process");
if (ref_data_tensors.size() != ref_out_tensors.size()) {
REPORT_INNER_ERROR("E19999", "op:%s(%s) input number[%zu] and output number[%zu] is not same!",
node->GetName().c_str(), node->GetType().c_str(), ref_data_tensors.size(),
ref_out_tensors.size());
GELOGE(GRAPH_FAILED, "[Check][Param] while op [%s] input number[%zu] and output number[%zu] is not same!",
node->GetName().c_str(), ref_data_tensors.size(), ref_out_tensors.size());
return GRAPH_FAILED;
}

// check input and output
for (size_t i = 0; i < ref_out_tensors.size(); i++) {
if (ref_out_tensors[i].size() != 1) {
REPORT_INNER_ERROR("E19999", "while op, every output should only find one output tensor in all graph!");
GELOGE(GRAPH_FAILED, "[Check][Param] while op, every output should only find one output tensor in all graph!");
return GRAPH_FAILED;
}
auto ref_out_tensor = ref_out_tensors[i].at(0);
for (auto &tensor : ref_data_tensors[i]) {
// ref_i's data and output tensor shape should be same
if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
REPORT_INNER_ERROR("E19999", "node[%s] does not support diff dtype or format among all ref output",
node->GetName().c_str());
GELOGE(GRAPH_FAILED, "[Check][Param] node[%s] does not support diff dtype or format output.",
node->GetName().c_str());
return GRAPH_FAILED;
}
auto data_shape = tensor.MutableShape();
auto out_shape = ref_out_tensor.MutableShape();
if (data_shape.GetDims() != out_shape.GetDims()) {
GELOGI("After infer, While %s %zu output shape [%s] is not match with input shape [%s].Need infer again.",
node->GetName().c_str(), i, out_shape.ToString().c_str(), data_shape.ToString().c_str());
if (data_shape.GetDimNum() != out_shape.GetDimNum()) {
ref_out_tensor.SetUnknownDimNumShape();
} else {
size_t data_dim_num = data_shape.GetDimNum();
std::vector<std::pair<int64_t, int64_t>> data_shape_range = {data_dim_num, std::make_pair(1, UNKNOWN_DIM)};
for (size_t j = 0; j < data_dim_num; ++j) {
if (data_shape.GetDim(j) != out_shape.GetDim(j)) {
data_shape.SetDim(j, UNKNOWN_DIM);
}
if (data_shape.GetDim(j) != UNKNOWN_DIM) {
data_shape_range[j] = std::make_pair(data_shape.GetDim(j), data_shape.GetDim(j));
}
}
ref_out_tensor.SetShape(data_shape);
ref_out_tensor.SetShapeRange(data_shape_range);
}
}
}

auto output_desc = node->GetOpDesc()->MutableOutputDesc(i);
(void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor);
bool output_changed = SameTensorDesc(ComGraphMakeShared<GeTensorDesc>(ref_out_tensor), output_desc);
if (output_changed) {
changed_nodes.insert(node);
}
}
return GRAPH_SUCCESS;
}

graphStatus InferBasePass::UpdateOutputForMultiBatch(NodePtr &node,
std::vector<std::vector<GeTensorDesc>> &ref_out_tensors,
std::set<NodePtr> &changed_nodes) {
// check sub_graph shape. Get max for update.
for (size_t i = 0; i < ref_out_tensors.size(); ++i) {
if (ref_out_tensors[i].empty()) {
continue;
}

int64_t max_size = 0;
size_t max_shape_index = 0;
auto &ref_out_tensor = ref_out_tensors[i].at(0);
for (size_t j = 0; j < ref_out_tensors[i].size(); ++j) {
auto &tensor = ref_out_tensors[i].at(j);
if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
REPORT_INNER_ERROR("E19999", "node[%s] does not support diff dtype among all ref output",
node->GetName().c_str());
GELOGE(GRAPH_FAILED, "[Check][Param] node[%s] does not support diff dtype among all ref output",
node->GetName().c_str());
return GRAPH_FAILED;
}

auto shape = tensor.MutableShape();
int64_t size = 1;
for (auto dim : shape.GetDims()) {
if (dim != 0 && INT64_MAX / dim < size) {
REPORT_INNER_ERROR("E19999", "The shape:%s size overflow, node:%s", shape.ToString().c_str(),
node->GetName().c_str());
GELOGE(PARAM_INVALID, "[Check][Overflow] The shape size overflow");
return PARAM_INVALID;
}
size *= dim;
}

if (size > max_size) {
max_size = size;
max_shape_index = j;
}
}

auto output_desc = node->GetOpDesc()->MutableOutputDesc(i);
(void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensors[i].at(max_shape_index));
bool output_changed =
SameTensorDesc(ComGraphMakeShared<GeTensorDesc>(ref_out_tensors[i].at(max_shape_index)), output_desc);
if (output_changed) {
changed_nodes.insert(node);
}
}

return GRAPH_SUCCESS;
}

graphStatus InferBasePass::UpdateParentNodeForBranch(NodePtr &node,
std::vector<std::vector<GeTensorDesc>> &ref_out_tensors,
std::set<NodePtr> &changed_nodes) {
GELOGD("Enter update parent node shape for class branch op process");
if (node->GetOpDesc()->HasAttr(ATTR_NAME_BATCH_NUM)) {
return UpdateOutputForMultiBatch(node, ref_out_tensors, changed_nodes);
}

// check sub_graph shape.If not same ,do unknown shape process
for (size_t i = 0; i < ref_out_tensors.size(); i++) {
if (ref_out_tensors[i].empty()) {
continue;
}
auto ref_out_tensor = ref_out_tensors[i].at(0);
ge::GeShape &ref_out_tensor_shape = ref_out_tensor.MutableShape();
for (auto &tensor : ref_out_tensors[i]) {
if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
REPORT_INNER_ERROR("E19999", "node[%s] does not support diff dtype among all ref output, shape:%s",
node->GetName().c_str(), ref_out_tensor_shape.ToString().c_str());
GELOGE(GRAPH_FAILED, "[Check][Param] node[%s] does not support diff dtype output", node->GetName().c_str());
return GRAPH_FAILED;
}
auto shape = tensor.MutableShape();
if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) {
GELOGD("node is %s, i : %zu, shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), i,
shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize());
ref_out_tensor_shape = GeShape(UNKNOWN_RANK);
break;
}
for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) {
if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) {
continue;
}
GELOGD("node is %s, i : %zu, j: %zu ,shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(),
i, j, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize());
(void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM);
}
}

auto output_desc = node->GetOpDesc()->MutableOutputDesc(i);
(void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor);
bool output_changed =
SameTensorDesc(ComGraphMakeShared<GeTensorDesc>(ref_out_tensor), output_desc);
if (output_changed) {
changed_nodes.insert(node);
}
}
return GRAPH_SUCCESS;
}

void InferBasePass::PrintInOutTensorShape(const NodePtr &node, const std::string &phase) {
if (!IsLogEnable(GE, DLOG_DEBUG)) {
return;
}
if (node == nullptr) {
REPORT_INNER_ERROR("E19999", "param node is nullprt, check invalid");
GELOGE(GRAPH_FAILED, "[Check][Param] node is null");
return;
}
ge::OpDescPtr op_desc = node->GetOpDesc();
GE_IF_BOOL_EXEC(op_desc == nullptr, REPORT_INNER_ERROR("E19999", "node has no opdesc, check invalid");
GELOGE(GRAPH_FAILED, "[Get][OpDesc] op_desc is null."); return );
std::stringstream ss;
ss << "{";
int32_t in_idx = 0;
int32_t out_idx = 0;
for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
if (input_desc == nullptr) {
in_idx++;
continue;
}
if (in_idx > 0) {
ss << " ";
}
ss << "input_" << in_idx << " "
<< "tensor: [";
ss << "(shape:[" << input_desc->MutableShape().ToString() << "]),";
ss << "(format:" << TypeUtils::FormatToSerialString(input_desc->GetFormat()) << "),";
ss << "(dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) << "),";
ss << "(origin_shape:" << input_desc->GetOriginShape().ToString() << "),";
ss << "(origin_format:" << TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()) << "),";
ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()) << "),";
string range_str;
SerialShapeRange(input_desc, range_str);
ss << "(shape_range:" << range_str << "),";
std::vector<std::pair<int64_t, int64_t>> value_range;
(void)input_desc->GetValueRange(value_range);
string value_range_str = formats::RangeToString(value_range);
ss << "(value_range:" << value_range_str << ")]";
in_idx++;
}
for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) {
if (output_desc == nullptr) {
out_idx++;
continue;
}
ss << " ";
ss << "output_" << out_idx << " "
<< "tensor: [";
ss << "(shape:[" << output_desc->MutableShape().ToString() << "]),";
ss << "(format:" << TypeUtils::FormatToSerialString(output_desc->GetFormat()) << "),";
ss << "(dtype:" << TypeUtils::DataTypeToSerialString(output_desc->GetDataType()) << "),";
ss << "(origin_shape:" << output_desc->GetOriginShape().ToString() << "),";
ss << "(origin_format:" << TypeUtils::FormatToSerialString(output_desc->GetOriginFormat()) << "),";
ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(output_desc->GetOriginDataType()) << "),";
string range_str;
SerialShapeRange(output_desc, range_str);
ss << "(shape_range:" << range_str << "),";
std::vector<std::pair<int64_t, int64_t>> value_range;
(void)output_desc->GetValueRange(value_range);
string value_range_str = formats::RangeToString(value_range);
ss << "(value_range:" << value_range_str << ")]";
out_idx++;
}
ss << "}";
GELOGD("Shape dump [%s], Node name: [%s]. %s", phase.c_str(), node->GetName().c_str(), ss.str().c_str());
}
} // namespace ge

+ 50
- 0
ge/graph/passes/infer_base_pass.h View File

@@ -0,0 +1,50 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GE_GRAPH_PASSES_INFER_BASE_PASS_H_
#define GE_GRAPH_PASSES_INFER_BASE_PASS_H_

#include "graph/passes/base_pass.h"

namespace ge {
class InferBasePass : public BaseNodePass {
public:
Status Run(NodePtr &node) override;
graphStatus InferAndUpdate(NodePtr &node, bool before_subgraph, std::set<NodePtr> &changed_nodes);
void PrintInOutTensorShape(const NodePtr &node, const std::string &phase);

protected:
virtual bool NeedInfer(const NodePtr &node);
virtual graphStatus Infer(NodePtr &node) = 0;
virtual bool SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDescPtr &dst) = 0;
virtual graphStatus UpdatePeerInputDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) = 0;

private:
void AddChangedNodesImmediateRepass(const std::set<NodePtr> &changed_nodes);
bool ContainsSubgraph(const NodePtr &node);
std::vector<ComputeGraphPtr> GetCurNodeSubgraphs(const NodePtr &node);
graphStatus UpdateTensorDescToSubgraphData(NodePtr &node, std::set<NodePtr> &changed_nodes);
graphStatus UpdateTensorDescToParentNode(NodePtr &node, std::set<NodePtr> &changed_nodes);
graphStatus UpdateParentNodeForWhile(NodePtr &node, std::vector<std::vector<GeTensorDesc>> &ref_data_tensors,
std::vector<std::vector<GeTensorDesc>> &ref_out_tensors,
std::set<NodePtr> &changed_nodes);
graphStatus UpdateParentNodeForBranch(NodePtr &node, std::vector<std::vector<GeTensorDesc>> &ref_out_tensors,
std::set<NodePtr> &changed_nodes);
graphStatus UpdateOutputForMultiBatch(NodePtr &node, std::vector<std::vector<GeTensorDesc>> &ref_out_tensors,
std::set<NodePtr> &changed_nodes);
graphStatus UpdateTensorDescToPeerInputs(NodePtr &node, std::set<NodePtr> &changed_nodes);
};
} // namespace ge
#endif // GE_GRAPH_PASSES_INFER_BASE_PASS_H_

+ 383
- 0
ge/graph/passes/infer_value_range_pass.cc View File

@@ -0,0 +1,383 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/passes/infer_value_range_pass.h"
#include "common/util/error_manager/error_manager.h"
#include "framework/common/debug/ge_log.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/operator_factory_impl.h"
#include "graph/passes/constant_folding_pass.h"
#include "graph/utils/type_utils.h"
#include "common/ge/ge_util.h"

using std::unique_ptr;
namespace ge {
namespace {
#define GET_DATA_BY_DTYPE(DTYPE, TYPE) \
case (DTYPE): \
ConstructValueRange<TYPE>(lower_tensor, higher_tensor, output_tensor_value_range); \
break;

Status RunCpuKernelForValueRange(NodePtr &node, const vector<ConstGeTensorPtr> &inputs,
std::vector<GeTensorPtr> &outputs) {
// should use RunOpKernelWithCheck, RunOpKernel for ut test
auto ret = ConstantFoldingPass::RunOpKernel(node, inputs, outputs);
if (ret != SUCCESS) {
auto op_kernel = folding_pass::GetKernelByType(node);
if (op_kernel == nullptr) {
GELOGW("Calculate value range failed, no op kernel for node %s type %s", node->GetName().c_str(),
node->GetType().c_str());
return NOT_CHANGED;
}

ret = op_kernel->Compute(node->GetOpDesc(), inputs, outputs);
if (ret != SUCCESS) {
GELOGW("Calculate for node %s failed in constant folding", node->GetName().c_str());
return NOT_CHANGED;
}
}
GELOGI("Node %s type %s, run cpu kernel success.", node->GetName().c_str(), node->GetType().c_str());
return SUCCESS;
}
} // namespace

graphStatus InferValueRangePass::Infer(NodePtr &node) {
PrintInOutTensorShape(node, "before_infer_value_range");
auto infer_value_range_param = OperatorFactoryImpl::GetInferValueRangePara(node->GetType());

// Use registered func to calculate value range
if (!infer_value_range_param.use_cpu_kernel) {
if (infer_value_range_param.infer_value_func == nullptr) {
GELOGW("The registered func of node %s to infer value range is nullptr.", node->GetName().c_str());
return GRAPH_NOT_CHANGED;
}
Operator op = OpDescUtils::CreateOperatorFromNode(node);
auto ret = node->GetOpDesc()->CallInferValueRangeFunc(op);
if (ret != GRAPH_SUCCESS) {
GELOGW("Node %s call infer value range func failed, ret: %u.", node->GetName().c_str(), ret);
return GRAPH_NOT_CHANGED;
}
return GRAPH_SUCCESS;
}

// Use CPU kernel func to calculate value range
auto ret = ConstructInputAndInferValueRange(node);
if (ret != GRAPH_SUCCESS) {
GELOGW("Use CPU kernel to calculate value range failed. node: %s, ret: %u", node->GetName().c_str(), ret);
return GRAPH_NOT_CHANGED;
}
return GRAPH_SUCCESS;
}

bool InferValueRangePass::NeedInfer(const NodePtr &node) {
auto infer_value_range_param = OperatorFactoryImpl::GetInferValueRangePara(node->GetType());
if (!infer_value_range_param.is_initialized) {
GELOGD("Node %s does not register func to infer value range, skip infer_value_range_pass.",
node->GetName().c_str());
return false;
}

if (infer_value_range_param.when_call == INPUT_IS_DYNAMIC) {
// Only do infer for node that all inputs are dynamic, such as shape
if (InputIsDynamic(node)) {
return true;
}
GELOGD("Node %s register func to infer value range and when_call is INPUT_IS_DYNAMIC, but check input failed.",
node->GetName().c_str());
} else if (infer_value_range_param.when_call == INPUT_HAS_VALUE_RANGE) {
// Only do infer for node that all inputs have value_range or node type of inputs is constant/const
if (InputIsConstOrHasValueRange(node)) {
return true;
}
GELOGD("Node %s register func to infer value range and when_call is INPUT_HAS_VALUE_RANGE, but check input failed.",
node->GetName().c_str());
}
GELOGD("Node %s does not need to infer value range, skip infer_value_range_pass.", node->GetName().c_str());
return false;
}

bool InferValueRangePass::InputIsDynamic(const NodePtr &node) {
bool input_is_dynamic = false;
auto cur_op_desc = node->GetOpDesc();
for (const auto &input_desc : cur_op_desc->GetAllInputsDescPtr()) {
auto dims = input_desc->GetShape().GetDims();
for (auto dim : dims) {
if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) {
input_is_dynamic = true;
break;
}
}
}
return input_is_dynamic;
}

bool InferValueRangePass::InputIsConstOrHasValueRange(const NodePtr &node) {
bool input_is_const_or_has_value_range = true;
auto cur_op_desc = node->GetOpDesc();
auto in_data_anchors = node->GetAllInDataAnchors();
for (size_t i = 0; i < in_data_anchors.size(); ++i) {
auto peer_out_anchor = in_data_anchors.at(i)->GetPeerOutAnchor();
if (peer_out_anchor == nullptr) {
continue;
}
auto peer_node = peer_out_anchor->GetOwnerNode();
if (peer_node == nullptr || peer_node->GetOpDesc() == nullptr) {
continue;
}
if ((peer_node->GetType() == CONSTANT) || (peer_node->GetType() == CONSTANTOP)) {
continue;
}

const auto &input_desc = cur_op_desc->GetInputDesc(i);
std::vector<std::pair<int64_t, int64_t>> value_range;
(void)input_desc.GetValueRange(value_range);
if (value_range.empty()) {
GELOGD("Node %s input %zu does not have value range, skip infer_value_range_pass for current node.",
node->GetName().c_str(), i);
input_is_const_or_has_value_range = false;
break;
}
}
return input_is_const_or_has_value_range;
}


bool InferValueRangePass::SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDescPtr &dst) {
bool same_desc = true;
std::vector<std::pair<int64_t, int64_t>> src_value_range;
std::vector<std::pair<int64_t, int64_t>> dst_value_range;
(void)src->GetValueRange(src_value_range);
(void)dst->GetValueRange(dst_value_range);
if (src_value_range != dst_value_range) {
same_desc = false;
}
return same_desc;
}

graphStatus InferValueRangePass::UpdatePeerInputDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) {
changed = false;
std::vector<std::pair<int64_t, int64_t>> src_value_range;
std::vector<std::pair<int64_t, int64_t>> dst_value_range;
(void)src->GetValueRange(src_value_range);
(void)dst->GetValueRange(dst_value_range);
if (src_value_range != dst_value_range) {
changed = true;
}

dst->SetValueRange(src_value_range);
return GRAPH_SUCCESS;
}

template <typename T>
graphStatus InferValueRangePass::ConstructData(const GeTensorDesc &tensor_desc, bool use_floor_value,
GeTensorPtr &output_ptr) {
std::vector<std::pair<int64_t, int64_t>> value_range;
(void)tensor_desc.GetValueRange(value_range);
if (static_cast<int64_t>(value_range.size()) != tensor_desc.GetShape().GetShapeSize()) {
REPORT_INNER_ERROR("E19999", "Value range of input %s is invalid.", tensor_desc.GetName().c_str());
GELOGE(GRAPH_PARAM_INVALID, "Value range of input %s is invalid.", tensor_desc.GetName().c_str());
return GRAPH_PARAM_INVALID;
}

size_t value_range_data_num = value_range.size();
unique_ptr<T[]> buf(new (std::nothrow) T[value_range_data_num]());
if (buf == nullptr) {
REPORT_INNER_ERROR("E19999", "New buf failed");
GELOGE(MEMALLOC_FAILED, "new buf failed");
return GRAPH_FAILED;
}
for (size_t j = 0; j < value_range_data_num; ++j) {
auto value_range_j = use_floor_value ? value_range[j].first : value_range[j].second;
buf[j] = static_cast<T>(value_range_j);
}

if (output_ptr->SetData(reinterpret_cast<uint8_t *>(buf.get()), value_range_data_num * sizeof(T)) != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "set data failed");
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}

graphStatus InferValueRangePass::ConstructDataByType(const GeTensorDesc &tensor_desc, bool use_floor_value,
GeTensorPtr &output_ptr) {
graphStatus ret = GRAPH_SUCCESS;
auto data_type = tensor_desc.GetDataType();
output_ptr->MutableTensorDesc().SetDataType(data_type);
switch (data_type) {
case DT_FLOAT:
ret = ConstructData<float>(tensor_desc, use_floor_value, output_ptr);
break;
case DT_DOUBLE:
ret = ConstructData<double>(tensor_desc, use_floor_value, output_ptr);
break;
case DT_UINT8:
ret = ConstructData<uint8_t>(tensor_desc, use_floor_value, output_ptr);
break;
case DT_INT8:
ret = ConstructData<int8_t>(tensor_desc, use_floor_value, output_ptr);
break;
case DT_UINT16:
ret = ConstructData<uint16_t>(tensor_desc, use_floor_value, output_ptr);
break;
case DT_INT16:
ret = ConstructData<int16_t>(tensor_desc, use_floor_value, output_ptr);
break;
case DT_INT32:
ret = ConstructData<int32_t>(tensor_desc, use_floor_value, output_ptr);
break;
case DT_INT64:
ret = ConstructData<int64_t>(tensor_desc, use_floor_value, output_ptr);
break;
default:
GELOGW("Data type:%s is not supported.", TypeUtils::DataTypeToSerialString(data_type).c_str());
ret = GRAPH_FAILED;
}
return ret;
}

vector<ConstGeTensorPtr> InferValueRangePass::ConstructInputTensors(const NodePtr &node, bool use_floor_value) {
vector<ConstGeTensorPtr> input_tensors;
auto cur_op_desc = node->GetOpDesc();
auto in_data_anchors = node->GetAllInDataAnchors();
for (size_t i = 0; i < in_data_anchors.size(); ++i) {
auto peer_out_anchor = in_data_anchors.at(i)->GetPeerOutAnchor();
if (peer_out_anchor == nullptr) {
continue;
}
auto peer_node = peer_out_anchor->GetOwnerNode();
if (peer_node == nullptr) {
continue;
}

// construct input tensor by constant node
if ((peer_node->GetType() == CONSTANT) || (peer_node->GetType() == CONSTANTOP)) {
vector<GeTensorPtr> const_weight = OpDescUtils::MutableWeights(peer_node);
if (const_weight.empty()) {
REPORT_INNER_ERROR("E19999", "MutableWeights failed, weight is empty, node: %s(%s)",
peer_node->GetName().c_str(), peer_node->GetType().c_str());
GELOGE(INTERNAL_ERROR, "MutableWeights failed, weight is empty, node: %s(%s)", peer_node->GetName().c_str(),
peer_node->GetType().c_str());
return vector<ConstGeTensorPtr>();
}
// const/constant op has only one weight
if (const_weight.at(0) == nullptr) {
REPORT_INNER_ERROR("E19999", "MutableWeights failed, weight of constant is null, node: %s(%s)",
peer_node->GetName().c_str(), peer_node->GetType().c_str());
GELOGE(INTERNAL_ERROR, "MutableWeights failed, weight of constant is null, node name: %s(%s)",
peer_node->GetName().c_str(), peer_node->GetType().c_str());
return vector<ConstGeTensorPtr>();
}
input_tensors.push_back(const_weight.at(0));
continue;
}

// construct input tensor by boundary of value range
const auto &input_tensor_desc = cur_op_desc->GetInputDesc(i);
GeTensorPtr tmp_tensor_ptr = MakeShared<GeTensor>(input_tensor_desc);
if (tmp_tensor_ptr == nullptr) {
REPORT_INNER_ERROR("E19999", "Make shared failed");
GELOGE(MEMALLOC_FAILED, "Make shared failed");
return vector<ConstGeTensorPtr>();
}

auto ret = ConstructDataByType(input_tensor_desc, use_floor_value, tmp_tensor_ptr);
if (ret != GRAPH_SUCCESS) {
REPORT_INNER_ERROR("E19999", "Input %s construct input tensor by boundary of value range failed.",
input_tensor_desc.GetName().c_str());
GELOGE(GRAPH_PARAM_INVALID, "Input %s construct input tensor by boundary of value range failed.",
input_tensor_desc.GetName().c_str());
return vector<ConstGeTensorPtr>();
}
input_tensors.push_back(tmp_tensor_ptr);
}

return input_tensors;
}

graphStatus InferValueRangePass::ConstructInputAndInferValueRange(NodePtr &node) {
auto inputs = ConstructInputTensors(node, true);
if (inputs.empty()) {
return GRAPH_PARAM_INVALID;
}
vector<GeTensorPtr> outputs_lower;
auto ret = RunCpuKernelForValueRange(node, inputs, outputs_lower);
if (ret != SUCCESS) {
REPORT_INNER_ERROR("E19999", "Calculate for node %s(%s) failed", node->GetName().c_str(), node->GetType().c_str());
GELOGE(GRAPH_FAILED, "Calculate for node %s failed in constant folding", node->GetName().c_str());
return GRAPH_FAILED;
}

inputs = ConstructInputTensors(node, false);
if (inputs.empty()) {
return GRAPH_PARAM_INVALID;
}
vector<GeTensorPtr> outputs_higher;
ret = RunCpuKernelForValueRange(node, inputs, outputs_higher);
if (ret != SUCCESS) {
REPORT_INNER_ERROR("E19999", "Calculate for node %s(%s) failed", node->GetName().c_str(), node->GetType().c_str());
GELOGE(GRAPH_FAILED, "Calculate for node %s failed in constant folding", node->GetName().c_str());
return GRAPH_FAILED;
}

// construct value range from output tensor
OpDescPtr node_desc = node->GetOpDesc();
std::vector<std::pair<int64_t, int64_t>> output_tensor_value_range;
size_t node_output_desc_size = node_desc->GetOutputsSize();
for (size_t i = 0; i < node_output_desc_size; ++i) {
output_tensor_value_range.clear();
auto lower_tensor = outputs_lower[i];
auto lower_tensor_shape_size = lower_tensor->GetTensorDesc().GetShape().GetShapeSize();
auto higher_tensor = outputs_higher[i];
auto higher_tensor_shape_size = higher_tensor->GetTensorDesc().GetShape().GetShapeSize();
auto output_tensor_desc = node_desc->MutableOutputDesc(i);
auto output_tensor_shape_size = output_tensor_desc->GetShape().GetShapeSize();
if (output_tensor_shape_size != lower_tensor_shape_size || output_tensor_shape_size != higher_tensor_shape_size) {
GELOGE(GRAPH_PARAM_INVALID, "Value range of output %s is invalid.", output_tensor_desc->GetName().c_str());
}

auto data_type = output_tensor_desc->GetDataType();
switch (data_type) {
GET_DATA_BY_DTYPE(DT_INT8, int8_t)
GET_DATA_BY_DTYPE(DT_INT16, int16_t)
GET_DATA_BY_DTYPE(DT_INT32, int32_t)
GET_DATA_BY_DTYPE(DT_INT64, int64_t)
GET_DATA_BY_DTYPE(DT_UINT8, uint8_t)
GET_DATA_BY_DTYPE(DT_UINT16, uint16_t)
GET_DATA_BY_DTYPE(DT_UINT32, uint32_t)
GET_DATA_BY_DTYPE(DT_UINT64, uint64_t)
GET_DATA_BY_DTYPE(DT_FLOAT, float)
GET_DATA_BY_DTYPE(DT_DOUBLE, double)
default:
GELOGW("Data type:%s is not supported.", TypeUtils::DataTypeToSerialString(data_type).c_str());
return GRAPH_FAILED;
}
output_tensor_desc->SetValueRange(output_tensor_value_range);
}
return GRAPH_SUCCESS;
}

template <typename T>
void InferValueRangePass::ConstructValueRange(const GeTensorPtr &left_tensor, const GeTensorPtr &right_tensor,
std::vector<std::pair<int64_t, int64_t>> &value_range) {
auto x = reinterpret_cast<const T *>(left_tensor->GetData().GetData());
auto y = reinterpret_cast<const T *>(right_tensor->GetData().GetData());
for (auto j = 0; j < left_tensor->GetTensorDesc().GetShape().GetShapeSize(); ++j) {
auto left = static_cast<int64_t>(*(x + j));
auto right = static_cast<int64_t>(*(y + j));
value_range.emplace_back(std::make_pair(left, right));
}
}
} // namespace ge

+ 45
- 0
ge/graph/passes/infer_value_range_pass.h View File

@@ -0,0 +1,45 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GRAPH_PASSES_INFER_VALUE_RANGE_PASS_H_
#define GE_GRAPH_PASSES_INFER_VALUE_RANGE_PASS_H_

#include "graph/passes/infer_base_pass.h"

namespace ge {
class InferValueRangePass : public InferBasePass {
public:
graphStatus Infer(NodePtr &node) override;

protected:
bool NeedInfer(const NodePtr &node) override;
bool SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDescPtr &dst) override;
graphStatus UpdatePeerInputDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override;

private:
bool InputIsDynamic(const NodePtr &node);
bool InputIsConstOrHasValueRange(const NodePtr &node);
template <typename T>
graphStatus ConstructData(const GeTensorDesc &tensor_desc, bool use_floor_value, GeTensorPtr &output_ptr);
graphStatus ConstructDataByType(const GeTensorDesc &tensor_desc, bool use_floor_value, GeTensorPtr &output_ptr);
vector<ConstGeTensorPtr> ConstructInputTensors(const NodePtr &node, bool use_floor_value);
template <typename T>
void ConstructValueRange(const GeTensorPtr &left_tensor, const GeTensorPtr &right_tensor,
std::vector<std::pair<int64_t, int64_t>> &value_range);
graphStatus ConstructInputAndInferValueRange(NodePtr &node);
};
} // namespace ge
#endif // GE_GRAPH_PASSES_INFER_VALUE_RANGE_PASS_H_

+ 3
- 0
ge/graph/preprocess/graph_preprocess.cc View File

@@ -54,6 +54,7 @@
#include "graph/passes/hccl_group_pass.h"
#include "graph/passes/identity_pass.h"
#include "graph/passes/infershape_pass.h"
#include "graph/passes/infer_value_range_pass.h"
#include "graph/passes/merge_pass.h"
#include "graph/passes/net_output_pass.h"
#include "graph/passes/no_use_reshape_remove_pass.h"
@@ -1997,6 +1998,8 @@ Status GraphPrepare::InferShapeForPreprocess() {
names_to_passes.emplace_back("MergePass", &merge_pass);
InferShapePass infer_shape_pass;
names_to_passes.emplace_back("InferShapePass", &infer_shape_pass);
InferValueRangePass infer_value_pass;
names_to_passes.emplace_back("InferValuePass", &infer_value_pass);
ReplaceWithEmptyConstPass replace_with_empty_const_pass;
names_to_passes.emplace_back("ReplaceWithEmptyConstPass", &replace_with_empty_const_pass);
DimensionComputePass dimension_compute_pass;


+ 5
- 0
tests/ut/ge/CMakeLists.txt View File

@@ -220,7 +220,9 @@ set(COMMON_SRC_FILES
"${GE_CODE_DIR}/ge/graph/passes/shape_operate_op_remove_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/assert_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/dropout_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/infer_base_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/infershape_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/infer_value_range_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/unused_const_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/permute_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/ctrl_edge_transfer_pass.cc"
@@ -533,7 +535,9 @@ set(GRAPH_PASS_COMMON_SRC_FILES
"${GE_CODE_DIR}/ge/graph/passes/transpose_transdata_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/hccl_memcpy_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/no_use_reshape_remove_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/infer_base_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/infershape_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/infer_value_range_pass.cc"
"${GE_CODE_DIR}/ge/ge_local_engine/engine/host_cpu_engine.cc"
"${GE_CODE_DIR}/ge/analyzer/analyzer.cc"
"${GE_CODE_DIR}/ge/graph/passes/net_output_pass.cc"
@@ -703,6 +707,7 @@ set(PASS_TEST_FILES
"graph/passes/net_output_pass_unittest.cc"
"graph/passes/no_use_reshape_remove_pass_unittest.cc"
"graph/passes/infershape_pass_unittest.cc"
"graph/passes/infer_value_range_pass_unittest.cc"
"graph/passes/mark_force_unknown_for_cond_pass_unittest.cc"
"graph/passes/multi_batch_clone_pass_unittest.cc"
"graph/passes/subgraph_const_migration_pass_unittest.cc"


+ 816
- 0
tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc View File

@@ -0,0 +1,816 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>

#define protected public
#define private public
#include "graph/passes/infer_value_range_pass.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph_builder_utils.h"

#include "inc/external/graph/operator_reg.h"
#include "inc/external/graph/operator.h"
#include "inc/external/graph/operator_factory.h"
#include "inc/graph/operator_factory_impl.h"
#include "inc/kernel.h"
#include "inc/kernel_factory.h"

using namespace std;
using namespace testing;
namespace ge {
class UtestGraphInferValueRangePass : public testing::Test {
protected:
void SetUp() {}
void TearDown() {}
};

/*
* data1 const1
* \ /
* case1
* |
* relu10
* |
* netoutput
*/
ut::GraphBuilder ParentGraphBuilder() {
ut::GraphBuilder builder = ut::GraphBuilder("g1");
auto data1 = builder.AddNode("data1", "Data", 0, 1);
std::vector<int64_t> const_shape = {1};
auto const1 = builder.AddNode("const1", "Const", 0, 1, FORMAT_NCHW, DT_INT32, const_shape);
auto case1 = builder.AddNode("case1", CASE, 2, 1);
auto relu1 = builder.AddNode("relu10", "Relu", 1, 1);
auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0);

int32_t weight[1] = {1};
GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_INT32);
GeTensorPtr tensor = std::make_shared<GeTensor>(weight_desc, (uint8_t *)weight, sizeof(weight));
OpDescUtils::SetWeights(const1, {tensor});
auto case_in0_shape = GeShape({1,1,-1,224});
case1->GetOpDesc()->MutableInputDesc(0)->SetShape(case_in0_shape);
std::vector<std::pair<int64_t, int64_t>> in_range = {make_pair(1, 1), make_pair(1, 1),
make_pair(1, -1), make_pair(1, 224)};
case1->GetOpDesc()->MutableInputDesc(0)->SetValueRange(in_range);
auto case_in1_shape = GeShape({1,1});
case1->GetOpDesc()->MutableInputDesc(1)->SetShape(case_in1_shape);

builder.AddDataEdge(data1, 0, case1, 0);
builder.AddDataEdge(const1, 0, case1, 1);
builder.AddDataEdge(case1, 0, relu1, 0);
builder.AddDataEdge(relu1, 0, netoutput, 0);
return builder;
}

/*
* data1 data2
* \ /
* switch
* / \
* relu1 relu2
* \ /
* merge
* |
* netoutput
*/
ut::GraphBuilder SwitchSubgraphBuilder(string graph_name, uint32_t num) {
ut::GraphBuilder builder = ut::GraphBuilder(graph_name);

std::vector<int64_t> shape1 = {2,2};
string data1_name = "data1_" + std::to_string(num);
auto data1 = builder.AddNode(data1_name, "Data", 1, 1, FORMAT_NCHW, DT_INT32, shape1);
auto data1_desc = data1->GetOpDesc();
EXPECT_NE(data1_desc, nullptr);
AttrUtils::SetInt(data1_desc, "_parent_node_index", 0);

std::vector<int64_t> shape2 = {3,3};
string data2_name = "data2_" + std::to_string(num);
auto data2 = builder.AddNode(data2_name, "Data", 1, 1, FORMAT_NCHW, DT_INT32, shape2);
auto data2_desc = data2->GetOpDesc();
EXPECT_NE(data2_desc, nullptr);
AttrUtils::SetInt(data2_desc, "_parent_node_index", 1);

string switch_name = "switch_" + std::to_string(num);
auto switch1 = builder.AddNode(switch_name, "Switch", 2, 2);

string relu1_name = "relu1_" + std::to_string(num);
auto relu1 = builder.AddNode(relu1_name, "Relu", 1, 1);

string relu2_name = "relu2_" + std::to_string(num);
auto relu2 = builder.AddNode(relu2_name, "Relu", 1, 1);

string merge_name = "merge_" + std::to_string(num);
auto merge = builder.AddNode(merge_name, "Merge", 2, 1);

std::vector<int64_t> shape7 = {8,8};
string output_name = "output_" + std::to_string(num);
auto netoutput = builder.AddNode(output_name, NETOUTPUT, 1, 0, FORMAT_NCHW, DT_INT32, shape7);
auto input0_desc = netoutput->GetOpDesc()->MutableInputDesc(0);
EXPECT_NE(input0_desc, nullptr);
AttrUtils::SetInt(input0_desc, "_parent_node_index", 0);
std::vector<std::pair<int64_t, int64_t>> range = {make_pair(1, -1), make_pair(1, -1)};
input0_desc->SetValueRange(range);

builder.AddDataEdge(data1, 0, switch1, 0);
builder.AddDataEdge(data2, 0, switch1, 1);
builder.AddDataEdge(switch1, 0, relu1, 0);
builder.AddDataEdge(switch1, 1, relu2, 0);
builder.AddDataEdge(relu1, 0, merge, 0);
builder.AddDataEdge(relu2, 0, merge, 1);
builder.AddDataEdge(merge, 0, netoutput, 0);

return builder;
}

void AddCaseSubgraph(ComputeGraphPtr &parent_graph, uint32_t branch_num) {
auto case_node = parent_graph->FindNode("case1");
EXPECT_NE(case_node, nullptr);

for (uint32_t i = 0; i < branch_num; ++i) {
string name = "Branch_Graph_" + std::to_string(i);

auto builder_subgraph = SwitchSubgraphBuilder(name, i);
auto switch_subgraph = builder_subgraph.GetGraph();

case_node->GetOpDesc()->AddSubgraphName(switch_subgraph->GetName());
case_node->GetOpDesc()->SetSubgraphInstanceName(i, switch_subgraph->GetName());

switch_subgraph->SetParentNode(case_node);
switch_subgraph->SetParentGraph(parent_graph);
EXPECT_EQ(parent_graph->AddSubgraph(switch_subgraph->GetName(), switch_subgraph), GRAPH_SUCCESS);
}
}

TEST_F(UtestGraphInferValueRangePass, infer_pass_not_register) {
auto graph = std::make_shared<ComputeGraph>("test_graph");
GeTensorDesc ge_tensor_desc(GeShape({1, 1, 4, 192}), ge::FORMAT_NCHW, DT_FLOAT16);
auto addn_op_desc = std::make_shared<OpDesc>("AddN", "AddN");
addn_op_desc->AddInputDesc(ge_tensor_desc);
addn_op_desc->AddOutputDesc(ge_tensor_desc);
auto addn_op_node = graph->AddNode(addn_op_desc);

InferValueRangePass infer_pass;
EXPECT_EQ(infer_pass.Run(addn_op_node), SUCCESS);
}

auto ShapeValueInfer = [&](Operator &op) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
auto output_tensor_desc = op_desc->MutableOutputDesc(0);
std::vector<std::pair<int64_t, int64_t>> in_shape_range;
op_desc->MutableInputDesc(0)->GetShapeRange(in_shape_range);
if (!in_shape_range.empty()) {
output_tensor_desc->SetValueRange(in_shape_range);
}
return SUCCESS;
};
REG_OP(Shape)
.OP_END_FACTORY_REG(Shape)
IMPL_INFER_VALUE_RANGE_FUNC(Shape, ShapeValueRangeFunc){
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
auto output_tensor_desc = op_desc->MutableOutputDesc(0);
std::vector<std::pair<int64_t, int64_t>> in_shape_range;
op_desc->MutableInputDesc(0)->GetShapeRange(in_shape_range);
if (!in_shape_range.empty()) {
output_tensor_desc->SetValueRange(in_shape_range);
}
return GRAPH_SUCCESS;
}

TEST_F(UtestGraphInferValueRangePass, infer_pass_when_call_1_not_infer) {
INFER_VALUE_RANGE_CUSTOM_FUNC_REG(Shape, INPUT_IS_DYNAMIC, ShapeValueRangeFunc);
auto graph = std::make_shared<ComputeGraph>("test_graph");
GeTensorDesc ge_tensor_desc(GeShape({1, 1, 4, 192}), ge::FORMAT_NCHW, DT_INT32);
std::vector<std::pair<int64_t, int64_t>> shape_range = {make_pair(1, 1), make_pair(1, 1),
make_pair(4, 4), make_pair(192, 192)};
ge_tensor_desc.SetShapeRange(shape_range);
GeTensorDesc output_tensor_desc(GeShape({4}), ge::FORMAT_NCHW, DT_INT32);
auto op_desc = std::make_shared<OpDesc>("Shape", "Shape");
op_desc->AddInputDesc(ge_tensor_desc);
op_desc->AddOutputDesc(output_tensor_desc);
auto op_node = graph->AddNode(op_desc);

InferValueRangePass infer_pass;
EXPECT_EQ(infer_pass.Run(op_node), SUCCESS);

auto output_0_desc = op_node->GetOpDesc()->GetOutputDesc(0);
std::vector<std::pair<int64_t, int64_t>> value_range;
output_0_desc.GetValueRange(value_range);
EXPECT_EQ(value_range.empty(), true);
}

TEST_F(UtestGraphInferValueRangePass, infer_pass_when_call_1_infer) {
// sqrt -> shape -> Output
INFER_VALUE_RANGE_CUSTOM_FUNC_REG(Shape, INPUT_IS_DYNAMIC, ShapeValueRangeFunc);
auto graph = std::make_shared<ComputeGraph>("test_graph");
GeTensorDesc sqrt_tensor_desc(GeShape({-1, -1, 4, 192}), ge::FORMAT_NCHW, DT_INT32);
std::vector<std::pair<int64_t, int64_t>> shape_range = {make_pair(1, 100), make_pair(1, 240),
make_pair(4, 4), make_pair(192, 192)};
sqrt_tensor_desc.SetShapeRange(shape_range);
auto sqrt_op_desc = std::make_shared<OpDesc>("Sqrt", "Sqrt");
sqrt_op_desc->AddInputDesc(sqrt_tensor_desc);
sqrt_op_desc->AddOutputDesc(sqrt_tensor_desc);
auto sqrt_node = graph->AddNode(sqrt_op_desc);

GeTensorDesc shape_output_desc(GeShape({4}), ge::FORMAT_NCHW, DT_INT32);
auto shape_op_desc = std::make_shared<OpDesc>("Shape", "Shape");
shape_op_desc->AddInputDesc(sqrt_tensor_desc);
shape_op_desc->AddOutputDesc(shape_output_desc);
auto shape_node = graph->AddNode(shape_op_desc);

GeTensorDesc Output_in_tensor_desc(GeShape({4}), ge::FORMAT_NCHW, ge::DT_INT32);
auto Output_op_desc = std::make_shared<OpDesc>("Output", "Output");
Output_op_desc->AddInputDesc(Output_in_tensor_desc);
auto Output_node = graph->AddNode(Output_op_desc);

ge::GraphUtils::AddEdge(sqrt_node->GetOutDataAnchor(0), shape_node->GetInDataAnchor(0));
ge::GraphUtils::AddEdge(shape_node->GetOutDataAnchor(0), Output_node->GetInDataAnchor(0));
EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS);


InferValueRangePass infer_pass;
auto ret = infer_pass.Run(shape_node);
EXPECT_EQ(ret, SUCCESS);

auto output_0_desc = shape_node->GetOpDesc()->GetOutputDesc(0);
std::vector<std::pair<int64_t, int64_t>> value_range;
output_0_desc.GetValueRange(value_range);
EXPECT_EQ(value_range.size(), 4);
std::vector<int64_t> target_value_range = {1, 100, 1, 240, 4, 4, 192, 192};
std::vector<int64_t> output_value_range;
for (auto pair : value_range) {
output_value_range.push_back(pair.first);
output_value_range.push_back(pair.second);
}
EXPECT_EQ(target_value_range, output_value_range);

auto in_0_desc = Output_node->GetOpDesc()->GetInputDesc(0);
value_range.clear();
in_0_desc.GetValueRange(value_range);
EXPECT_EQ(value_range.size(), 4);
output_value_range.clear();
for (auto pair : value_range) {
output_value_range.push_back(pair.first);
output_value_range.push_back(pair.second);
}
EXPECT_EQ(target_value_range, output_value_range);

}

class AddKernel : public Kernel {
public:
Status Compute(const ge::OpDescPtr op_desc_ptr, const std::vector<ge::ConstGeTensorPtr> &input,
std::vector<ge::GeTensorPtr> &v_output) override {
vector<int64_t> data_vec;
auto data_num = input[0]->GetTensorDesc().GetShape().GetShapeSize();
auto x1_data = reinterpret_cast<const int64_t *>(input[0]->GetData().data());
auto x2_data = reinterpret_cast<const int64_t *>(input[1]->GetData().data());
for (size_t i = 0; i < data_num; i++) {
auto x_index = *(x1_data + i);
auto y_index = *(x2_data + i);
data_vec.push_back(x_index + y_index);
}
GeTensorPtr const_tensor = std::make_shared<ge::GeTensor>(input[0]->GetTensorDesc(), (uint8_t *)data_vec.data(),
data_num * sizeof(int64_t));
v_output.emplace_back(const_tensor);
return SUCCESS;
}
};
REGISTER_KERNEL(ADD, AddKernel);

TEST_F(UtestGraphInferValueRangePass, infer_pass_when_call_2_infer) {
// shape --- add --- sqrt
// constant /
INFER_VALUE_RANGE_DEFAULT_REG(Add);
INFER_VALUE_RANGE_DEFAULT_REG("Sqrt");
auto graph = std::make_shared<ComputeGraph>("test_graph");

vector<int64_t> dims_vec = {4};
vector<int64_t> data_vec = {1, 1, 1, 1};
GeTensorDesc const_tensor_desc(ge::GeShape(dims_vec), ge::FORMAT_NCHW, ge::DT_INT64);
GeTensorPtr const_tensor =
std::make_shared<ge::GeTensor>(const_tensor_desc, (uint8_t *)data_vec.data(), data_vec.size() * sizeof(int64_t));

auto const_op_desc = std::make_shared<OpDesc>("Constant", "Constant");
const_op_desc->AddOutputDesc(const_tensor_desc);
EXPECT_EQ(OpDescUtils::SetWeights(const_op_desc, const_tensor), GRAPH_SUCCESS);
auto const_node = graph->AddNode(const_op_desc);

GeTensorDesc shape_tensor_desc(GeShape({4}), ge::FORMAT_NCHW, ge::DT_INT64);
std::vector<std::pair<int64_t, int64_t>> value_range = {make_pair(1, 100), make_pair(1, 240),
make_pair(4, 4), make_pair(192, 192)};
shape_tensor_desc.SetValueRange(value_range);
auto shape_op_desc = std::make_shared<OpDesc>("Shape", "Shape");
shape_op_desc->AddOutputDesc(shape_tensor_desc);
auto shape_node = graph->AddNode(shape_op_desc);

GeTensorDesc add_tensor_desc(GeShape({4}), ge::FORMAT_NCHW, ge::DT_INT64);
auto add_op_desc = std::make_shared<OpDesc>("Add", "Add");
add_op_desc->AddInputDesc(shape_tensor_desc);
add_op_desc->AddInputDesc(const_tensor_desc);
add_op_desc->AddOutputDesc(add_tensor_desc);
auto add_node = graph->AddNode(add_op_desc);

auto sqrt_op_desc = std::make_shared<OpDesc>("Sqrt", "Sqrt");
sqrt_op_desc->AddInputDesc(GeTensorDesc());
auto sqrt_node = graph->AddNode(sqrt_op_desc);

ge::GraphUtils::AddEdge(shape_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(0));
ge::GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(1));
ge::GraphUtils::AddEdge(add_node->GetOutDataAnchor(0), sqrt_node->GetInDataAnchor(1));

InferValueRangePass infer_pass;
EXPECT_EQ(infer_pass.Run(sqrt_node), SUCCESS);
EXPECT_EQ(infer_pass.Run(add_node), SUCCESS);

auto output_0_desc = add_node->GetOpDesc()->GetOutputDesc(0);
std::vector<std::pair<int64_t, int64_t>> out_value_range;
output_0_desc.GetValueRange(out_value_range);
EXPECT_EQ(out_value_range.size(), 4);

std::vector<int64_t> target_value_range = {2, 101, 2, 241, 5, 5, 193, 193};
std::vector<int64_t> output_value_range;
for (auto pair : out_value_range) {
output_value_range.push_back(pair.first);
output_value_range.push_back(pair.second);
}
EXPECT_EQ(target_value_range, output_value_range);
}

REG_OP(Case)
.OP_END_FACTORY_REG(Case)
IMPL_INFER_VALUE_RANGE_FUNC(Case, ValueRangeFunc){
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
auto output_tensor_desc = op_desc->MutableOutputDesc(0);
std::vector<std::pair<int64_t, int64_t>> in_shape_range = {make_pair(1, 2), make_pair(1, 3),
make_pair(1, 4), make_pair(1, 5)};;
output_tensor_desc->SetValueRange(in_shape_range);
return GRAPH_SUCCESS;
}
TEST_F(UtestGraphInferValueRangePass, infer_with_case_subgraph) {
auto builder = ParentGraphBuilder();
auto parent_graph = builder.GetGraph();
AddCaseSubgraph(parent_graph, 2);
auto subgraphs = parent_graph->GetAllSubgraphs();
EXPECT_EQ(subgraphs.size(), 2);

// check before subgraph
auto case_node = parent_graph->FindNode("case1");
EXPECT_NE(case_node, nullptr);
INFER_VALUE_RANGE_CUSTOM_FUNC_REG(Case, INPUT_HAS_VALUE_RANGE, ValueRangeFunc);
InferValueRangePass infer_pass;
EXPECT_EQ(infer_pass.Run(case_node), SUCCESS);

auto case_out_0_desc = case_node->GetOpDesc()->MutableOutputDesc(0);
std::vector<std::pair<int64_t, int64_t>> out_value_range;
case_out_0_desc->GetValueRange(out_value_range);
EXPECT_EQ(out_value_range.size(), 4);
std::vector<int64_t> target_value_range = {1,2,1,3,1,4,1,5};
std::vector<int64_t> output_value_range_list;
for (auto pair : out_value_range) {
output_value_range_list.push_back(pair.first);
output_value_range_list.push_back(pair.second);
}
EXPECT_EQ(target_value_range, output_value_range_list);

std::vector<int64_t> target_dims_0 = {1, 1, -1, 224};
std::vector<int64_t> target_dims_1 = {1,1};
auto data_node = subgraphs[0]->FindNode("data1_0");
auto data_output_0_desc = data_node->GetOpDesc()->GetOutputDesc(0);
EXPECT_EQ(target_dims_0, data_output_0_desc.GetShape().GetDims());
data_node = subgraphs[0]->FindNode("data2_0");
auto data2_output_0_desc = data_node->GetOpDesc()->GetOutputDesc(0);
EXPECT_EQ(target_dims_1, data2_output_0_desc.GetShape().GetDims());

// check after subgraph
infer_pass.options_[kOptimizeAfterSubGraph] = "yes";
EXPECT_EQ(infer_pass.Run(case_node), SUCCESS);

std::vector<int64_t> out_target_dims = {1, -1, 1, -1};
auto case_out = case_node->GetOpDesc()->GetOutputDescPtr(0);
out_value_range.clear();
case_out->GetValueRange(out_value_range);
EXPECT_EQ(out_value_range.size(), 2);

output_value_range_list.clear();
for (auto pair : out_value_range) {
output_value_range_list.push_back(pair.first);
output_value_range_list.push_back(pair.second);
}
EXPECT_EQ(out_target_dims, output_value_range_list);
}

/*
* data1 const1
* \ /
* while
* / \
* relu1 netoutput
*/
ut::GraphBuilder ParentWhileGraphBuilder() {
ut::GraphBuilder builder = ut::GraphBuilder("g1");
auto data1 = builder.AddNode("data1", "Data", 0, 1);
std::vector<int64_t> const_shape = {1};
auto const1 = builder.AddNode("const1", "Const", 0, 1, FORMAT_NCHW, DT_FLOAT, const_shape);
auto while1 = builder.AddNode("while1", WHILE, 2, 2);
auto relu1 = builder.AddNode("relu1", "Relu", 1, 1);
auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0);

int32_t weight[1] = {1};
GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_FLOAT);
GeTensorPtr tensor = std::make_shared<GeTensor>(weight_desc, (uint8_t *)weight, sizeof(weight));
OpDescUtils::SetWeights(const1, {tensor});
std::vector<std::pair<int64_t, int64_t>> in_range = {make_pair(1, 1), make_pair(1, 1),
make_pair(1, 224), make_pair(1, 224)};
while1->GetOpDesc()->MutableInputDesc(0)->SetValueRange(in_range);

builder.AddDataEdge(data1, 0, while1, 0);
builder.AddDataEdge(const1, 0, while1, 1);
builder.AddDataEdge(while1, 0, relu1, 0);
builder.AddDataEdge(while1, 1, netoutput, 0);
return builder;
}

/*
* data1 data2
* \ /
* switch
* | |
* \ /
* netoutput
*/
ut::GraphBuilder WhileSubgraphBuilder(string graph_name, uint32_t num) {
ut::GraphBuilder builder = ut::GraphBuilder(graph_name);

std::vector<int64_t> shape1 = {2,2};
string data1_name = "data1_" + std::to_string(num);
auto data1 = builder.AddNode(data1_name, "Data", 1, 1, FORMAT_NCHW, DT_FLOAT, shape1);
auto data1_desc = data1->GetOpDesc();
EXPECT_NE(data1_desc, nullptr);
AttrUtils::SetInt(data1_desc, "_parent_node_index", 0);

std::vector<int64_t> shape2 = {3,3};
string data2_name = "data2_" + std::to_string(num);
auto data2 = builder.AddNode(data2_name, "Data", 1, 1, FORMAT_NCHW, DT_FLOAT, shape2);
auto data2_desc = data2->GetOpDesc();
EXPECT_NE(data2_desc, nullptr);
AttrUtils::SetInt(data2_desc, "_parent_node_index", 1);

string switch_name = "switch_" + std::to_string(num);
auto switch1 = builder.AddNode(switch_name, "Switch", 2, 2);

std::vector<int64_t> shape7 = {8,8,8,8};
string output_name = "output_" + std::to_string(num);
auto netoutput = builder.AddNode(output_name, NETOUTPUT, 2, 0, FORMAT_NCHW, DT_FLOAT, shape7);
auto input0_desc = netoutput->GetOpDesc()->MutableInputDesc(0);
EXPECT_NE(input0_desc, nullptr);
AttrUtils::SetInt(input0_desc, "_parent_node_index", 0);
std::vector<std::pair<int64_t, int64_t>> range0 = {make_pair(1, -1), make_pair(1, -1)};
input0_desc->SetValueRange(range0);
auto input1_desc = netoutput->GetOpDesc()->MutableInputDesc(1);
EXPECT_NE(input1_desc, nullptr);
AttrUtils::SetInt(input1_desc, "_parent_node_index", 1);
std::vector<std::pair<int64_t, int64_t>> range1 = {make_pair(8, 8), make_pair(8, 8),make_pair(8, 8),make_pair(8, 8)};
input1_desc->SetValueRange(range1);

builder.AddDataEdge(data1, 0, switch1, 0);
builder.AddDataEdge(data2, 0, switch1, 1);
builder.AddDataEdge(switch1, 0, netoutput, 0);
builder.AddDataEdge(switch1, 1, netoutput, 1);
return builder;
}

void AddWhileSubgraph(ComputeGraphPtr &parent_graph, uint32_t branch_num) {
auto while_node = parent_graph->FindNode("while1");
EXPECT_NE(while_node, nullptr);

for (uint32_t i = 0; i < branch_num; ++i) {
string name = "Branch_Graph_" + std::to_string(i);

auto builder_subgraph = WhileSubgraphBuilder(name, i);
auto switch_subgraph = builder_subgraph.GetGraph();

while_node->GetOpDesc()->AddSubgraphName(switch_subgraph->GetName());
while_node->GetOpDesc()->SetSubgraphInstanceName(i, switch_subgraph->GetName());
switch_subgraph->SetParentNode(while_node);
switch_subgraph->SetParentGraph(parent_graph);
EXPECT_EQ(parent_graph->AddSubgraph(switch_subgraph->GetName(), switch_subgraph), GRAPH_SUCCESS);
}
}

REG_OP(While)
.OP_END_FACTORY_REG(While)
IMPL_INFER_VALUE_RANGE_FUNC(While, WhileValueRangeFunc){
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
std::vector<std::pair<int64_t, int64_t>> in_range = {make_pair(1, 2), make_pair(1, 3),
make_pair(1, 4), make_pair(1, 5)};;
for (auto i =0; i<op_desc->GetOutputsSize();++i){
auto output_tensor_desc = op_desc->MutableOutputDesc(i);
output_tensor_desc->SetValueRange(in_range);
}
return GRAPH_SUCCESS;
}
INFER_VALUE_RANGE_CUSTOM_FUNC_REG(While, INPUT_HAS_VALUE_RANGE, WhileValueRangeFunc);
TEST_F(UtestGraphInferValueRangePass, infer_with_while_subgraph) {
auto builder = ParentWhileGraphBuilder();
auto parent_graph = builder.GetGraph();
AddWhileSubgraph(parent_graph, 1);
auto subgraphs = parent_graph->GetAllSubgraphs();
EXPECT_EQ(subgraphs.size(), 1);

// check before subgraph
auto while_node = parent_graph->FindNode("while1");
EXPECT_NE(while_node, nullptr);
InferValueRangePass infer_pass;
EXPECT_EQ(infer_pass.Run(while_node), SUCCESS);

auto while_out_0_desc = while_node->GetOpDesc()->MutableOutputDesc(0);
std::vector<std::pair<int64_t, int64_t>> out_value_range;
while_out_0_desc->GetValueRange(out_value_range);
EXPECT_EQ(out_value_range.size(), 4);
std::vector<int64_t> target_value_range = {1,2,1,3,1,4,1,5};
std::vector<int64_t> output_value_range_list;
for (auto pair : out_value_range) {
output_value_range_list.push_back(pair.first);
output_value_range_list.push_back(pair.second);
}
EXPECT_EQ(target_value_range, output_value_range_list);

std::vector<int64_t> target_dims_0 = {1, 1, 224, 224};
auto data_node = subgraphs[0]->FindNode("data1_0");
auto data_input_0_desc = data_node->GetOpDesc()->GetInputDesc(0);
EXPECT_EQ(target_dims_0, data_input_0_desc.GetShape().GetDims());

// check after subgraph
infer_pass.options_[kOptimizeAfterSubGraph] = "yes";
EXPECT_EQ(infer_pass.Run(while_node), SUCCESS);

std::vector<int64_t> out_target_dims = {1, -1, 1, -1};
auto while_out0 = while_node->GetOpDesc()->GetOutputDescPtr(0);
out_value_range.clear();
while_out0->GetValueRange(out_value_range);
EXPECT_EQ(out_value_range.size(), 2);
output_value_range_list.clear();
for (auto pair : out_value_range) {
output_value_range_list.push_back(pair.first);
output_value_range_list.push_back(pair.second);
}
EXPECT_EQ(output_value_range_list, out_target_dims);

std::vector<int64_t> out_target_dims_1 = {8,8, 8,8, 8,8, 8,8};
auto while_out1 = while_node->GetOpDesc()->GetOutputDescPtr(1);
out_value_range.clear();
while_out1->GetValueRange(out_value_range);
EXPECT_EQ(out_value_range.size(), 4);
output_value_range_list.clear();
for (auto pair : out_value_range) {
output_value_range_list.push_back(pair.first);
output_value_range_list.push_back(pair.second);
}
EXPECT_EQ(output_value_range_list, out_target_dims_1);
}

TEST_F(UtestGraphInferValueRangePass, infer_with_while_subgraph_failed) {
auto builder = ParentWhileGraphBuilder();
auto parent_graph = builder.GetGraph();
AddWhileSubgraph(parent_graph, 2);
auto subgraphs = parent_graph->GetAllSubgraphs();
EXPECT_EQ(subgraphs.size(), 2);

auto case_node = parent_graph->FindNode("while1");
EXPECT_NE(case_node, nullptr);
InferValueRangePass infer_pass;
infer_pass.options_[kOptimizeAfterSubGraph] = "yes";
EXPECT_EQ(infer_pass.Run(case_node), GRAPH_FAILED);
}



bool IsEmptyTensor(const GeShape &ge_shape) {
bool is_empty = false;
for (const auto &dim : ge_shape.GetDims()) {
if (dim == 0) {
is_empty = true;
break;
}
}
return is_empty;
}
bool IsEmptyTensor(GeTensorDescPtr tensor_desc) {
return IsEmptyTensor(tensor_desc->MutableShape());
}
graphStatus ReshapeRangeInferAllDims(const std::vector<std::pair<int64_t, int64_t>> &x_shape_range,
const GeShape &x_shape,
const std::vector<std::pair<int64_t, int64_t>> &shape_value_range,
std::vector<std::pair<int64_t, int64_t>> &y_shape_range, GeShape &y_shape) {
// input_shape is not constant, can not get accurate shape value.
if (x_shape.GetDims() == UNKNOWN_RANK) {
return GRAPH_SUCCESS;
}

// step 1, calculate input_x range max
int64_t range_max = 1;
auto x_shape_size = x_shape.GetShapeSize();
if (x_shape_size > 0) {
// known dim, x_shape_size == range_max
range_max = x_shape_size;
} else {
// unknown dim
if (x_shape_range.empty()) {
return GRAPH_SUCCESS;
}
for (const auto &pair : x_shape_range) {
if (pair.second < 0) {
range_max = -1;
break;
}
range_max *= pair.second;
}
}

// step 2, init y shape range
auto y_rank = y_shape.GetDims().size();
auto shape_range_max = (range_max > INT32_MAX) ? INT32_MAX : range_max;
for (auto i = 0; i < y_rank; ++i) {
y_shape_range.emplace_back(std::pair<int64_t, int64_t>(1, shape_range_max));
}
if (shape_value_range.empty()) {
// no value range, can not calculate accurate shape range.
return GRAPH_SUCCESS;
}

// step 2, repair value range and check zero in value range
bool has_zero_in_value_range = false;
std::vector<std::pair<int64_t, int64_t>> value_range = shape_value_range;
for (auto &pair : value_range) {
if (pair.first < 0) {
pair.first = 1;
}
if (pair.second < 0) {
pair.second = -1;
}
if (pair.first == 0) {
has_zero_in_value_range = true;
}
}

// step 3, deal with empty tensor. if no value range cannot infer empty tensor.
if (IsEmptyTensor(x_shape)) {
if (range_max != 0) {
return GRAPH_FAILED;
}
if (!has_zero_in_value_range) {
return GRAPH_FAILED;
}
std::vector<int64_t> y_dims = y_shape.GetDims();
for (auto i = 0; i < y_rank; ++i) {
if (value_range[i].first == value_range[i].second) {
y_dims[i] = value_range[i].first;
}
}
y_shape_range = value_range;
y_shape = GeShape(y_dims);
return GRAPH_SUCCESS;
}

// step 4, calculate accurate dims and shape_range
std::vector<int64_t> y_dims = y_shape.GetDims();
for (auto i = 0; i < y_rank; ++i) {
if (value_range[i].first == value_range[i].second) {
y_dims[i] = value_range[i].first;
y_shape_range[i] = std::pair<int64_t, int64_t>(y_dims[i], y_dims[i]);
} else {
if (range_max == -1) {
// while range_max = -1, range_max && value_range[i].second is always value_range[i].second;
y_shape_range[i] = std::pair<int64_t, int64_t>(value_range[i].first, value_range[i].second);
continue;
}
int64_t other_dims_range_lower_boundary = 1;
for (auto j = 0; j < y_rank; ++j) {
if (i == j) {
continue;
}
other_dims_range_lower_boundary *= value_range[j].first;

}
int64_t cur_dim_range_max = static_cast<int64_t>(
(static_cast<double>(range_max) + other_dims_range_lower_boundary - 1) / other_dims_range_lower_boundary);
if (value_range[i].second == -1) {
cur_dim_range_max = (cur_dim_range_max < INT32_MAX) ? cur_dim_range_max : INT32_MAX;
y_shape_range[i] = std::pair<int64_t, int64_t>(value_range[i].first, cur_dim_range_max);
continue;
}
cur_dim_range_max = (cur_dim_range_max < value_range[i].second) ? cur_dim_range_max : value_range[i].second;
cur_dim_range_max = (cur_dim_range_max < INT32_MAX) ? cur_dim_range_max : INT32_MAX;
y_shape_range[i] = std::pair<int64_t, int64_t>(value_range[i].first, cur_dim_range_max);
}
}
y_shape = GeShape(y_dims);
return GRAPH_SUCCESS;
}

TEST_F(UtestGraphInferValueRangePass, reshape_infer_func_test_1) {
auto rank = 4;
std::vector<std::pair<int64_t, int64_t>> x_shape_range = {make_pair(1, 100), make_pair(1, 400)};
GeShape x_shape = GeShape(std::vector<int64_t>(2, UNKNOWN_DIM));
std::vector<std::pair<int64_t, int64_t>> shape_value_range = {make_pair(100, -1), make_pair(-1, -10),
make_pair(1, 20), make_pair(10, 10)};
std::vector<std::pair<int64_t, int64_t>> y_shape_range;
GeShape y_shape = GeShape(std::vector<int64_t>(rank, UNKNOWN_DIM));
auto ret = ReshapeRangeInferAllDims(x_shape_range, x_shape, shape_value_range, y_shape_range, y_shape);
EXPECT_EQ(ret, GRAPH_SUCCESS);

std::vector<int64_t> target_y_shape_dims = {-1, -1, -1, 10};
EXPECT_EQ(y_shape.GetDims(), target_y_shape_dims);

std::vector<int64_t> target_y_shape_range = {100, 4000, 1, 40, 1, 20, 10, 10};
std::vector<int64_t> output_shape_range;
for (auto pair : y_shape_range) {
output_shape_range.push_back(pair.first);
output_shape_range.push_back(pair.second);
}
EXPECT_EQ(output_shape_range, target_y_shape_range);
}

TEST_F(UtestGraphInferValueRangePass, reshape_infer_func_test_2) {
auto rank = 4;
std::vector<std::pair<int64_t, int64_t>> x_shape_range = {make_pair(1, 100), make_pair(1, 400), make_pair(-1, -1)};
GeShape x_shape = GeShape(std::vector<int64_t>(3, UNKNOWN_DIM));
std::vector<std::pair<int64_t, int64_t>> shape_value_range = {make_pair(100, -1), make_pair(1, -10),
make_pair(1, 20), make_pair(10, 10)};
std::vector<std::pair<int64_t, int64_t>> y_shape_range;
GeShape y_shape = GeShape(std::vector<int64_t>(rank, UNKNOWN_DIM));
auto ret = ReshapeRangeInferAllDims(x_shape_range, x_shape, shape_value_range, y_shape_range, y_shape);
EXPECT_EQ(ret, GRAPH_SUCCESS);

std::vector<int64_t> target_y_shape_dims = {-1, -1, -1, 10};
EXPECT_EQ(y_shape.GetDims(), target_y_shape_dims);

std::vector<int64_t> target_y_shape_range = {100, -1, 1, -1, 1, 20, 10, 10};
std::vector<int64_t> output_shape_range;
for (auto pair : y_shape_range) {
output_shape_range.push_back(pair.first);
output_shape_range.push_back(pair.second);
}
EXPECT_EQ(output_shape_range, target_y_shape_range);
}

TEST_F(UtestGraphInferValueRangePass, reshape_infer_func_test_3) {
auto rank = 4;
std::vector<std::pair<int64_t, int64_t>> x_shape_range = {};
GeShape x_shape = GeShape(std::vector<int64_t>(3, 100));
std::vector<std::pair<int64_t, int64_t>> shape_value_range = {make_pair(100, -1), make_pair(1, -10),
make_pair(1, 20), make_pair(10, 10)};
std::vector<std::pair<int64_t, int64_t>> y_shape_range;
GeShape y_shape = GeShape(std::vector<int64_t>(rank, UNKNOWN_DIM));
auto ret = ReshapeRangeInferAllDims(x_shape_range, x_shape, shape_value_range, y_shape_range, y_shape);
EXPECT_EQ(ret, GRAPH_SUCCESS);

std::vector<int64_t> target_y_shape_dims = {-1, -1, -1, 10};
EXPECT_EQ(y_shape.GetDims(), target_y_shape_dims);

std::vector<int64_t> target_y_shape_range = {100, 100000, 1, 1000, 1, 20, 10, 10};
std::vector<int64_t> output_shape_range;
for (auto pair : y_shape_range) {
output_shape_range.push_back(pair.first);
output_shape_range.push_back(pair.second);
}
EXPECT_EQ(output_shape_range, target_y_shape_range);
}
TEST_F(UtestGraphInferValueRangePass, reshape_infer_func_test_4) {
auto rank = 4;
std::vector<std::pair<int64_t, int64_t>> x_shape_range = {make_pair(1, 100), make_pair(0, 0)};
GeShape x_shape = GeShape({-1, 0});
std::vector<std::pair<int64_t, int64_t>> shape_value_range = {make_pair(0, 0), make_pair(-1, -10),
make_pair(10, 20), make_pair(100, 100)};
std::vector<std::pair<int64_t, int64_t>> y_shape_range;
GeShape y_shape = GeShape(std::vector<int64_t>(rank, UNKNOWN_DIM));
auto ret = ReshapeRangeInferAllDims(x_shape_range, x_shape, shape_value_range, y_shape_range, y_shape);
EXPECT_EQ(ret, GRAPH_SUCCESS);

std::vector<int64_t> target_y_shape_dims = {0, -1, -1, 100};
EXPECT_EQ(y_shape.GetDims(), target_y_shape_dims);

std::vector<int64_t> target_y_shape_range = {0, 0, 1, -1, 10, 20, 100, 100};
std::vector<int64_t> output_shape_range;
for (auto pair : y_shape_range) {
output_shape_range.push_back(pair.first);
output_shape_range.push_back(pair.second);
}
EXPECT_EQ(output_shape_range, target_y_shape_range);
}

} // namespace ge

Loading…
Cancel
Save