Merge pull request !1860 from 王强/master (tags/v1.5.1)
@@ -298,7 +298,9 @@ set(TRAIN_SRC_LIST | |||||
"graph/passes/hccl_continuous_memcpy_pass.cc" | "graph/passes/hccl_continuous_memcpy_pass.cc" | ||||
"graph/passes/identity_pass.cc" | "graph/passes/identity_pass.cc" | ||||
"graph/passes/ref_identity_delete_op_pass.cc" | "graph/passes/ref_identity_delete_op_pass.cc" | ||||
"graph/passes/infer_base_pass.cc" | |||||
"graph/passes/infershape_pass.cc" | "graph/passes/infershape_pass.cc" | ||||
"graph/passes/infer_value_range_pass.cc" | |||||
"graph/passes/iterator_op_pass.cc" | "graph/passes/iterator_op_pass.cc" | ||||
"graph/passes/link_gen_mask_nodes_pass.cc" | "graph/passes/link_gen_mask_nodes_pass.cc" | ||||
"graph/passes/merge_pass.cc" | "graph/passes/merge_pass.cc" | ||||
@@ -547,7 +549,9 @@ set(INFER_SRC_LIST | |||||
"graph/passes/shape_operate_op_remove_pass.cc" | "graph/passes/shape_operate_op_remove_pass.cc" | ||||
"graph/passes/assert_pass.cc" | "graph/passes/assert_pass.cc" | ||||
"graph/passes/dropout_pass.cc" | "graph/passes/dropout_pass.cc" | ||||
"graph/passes/infer_base_pass.cc" | |||||
"graph/passes/infershape_pass.cc" | "graph/passes/infershape_pass.cc" | ||||
"graph/passes/infer_value_range_pass.cc" | |||||
"graph/passes/unused_const_pass.cc" | "graph/passes/unused_const_pass.cc" | ||||
"graph/passes/permute_pass.cc" | "graph/passes/permute_pass.cc" | ||||
"graph/passes/ctrl_edge_transfer_pass.cc" | "graph/passes/ctrl_edge_transfer_pass.cc" | ||||
@@ -49,6 +49,25 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string ShapeToString(const s | |||||
return JoinToString(shape); | return JoinToString(shape); | ||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
std::string RangeToString(const std::vector<std::pair<int64_t, int64_t>> &ranges) {
  // Serialize a list of (lower, upper) range pairs as "[{lo,hi},{lo,hi},...]".
  std::stringstream ss;
  ss << "[";
  const char *separator = "";
  for (const auto &range : ranges) {
    ss << separator << "{" << range.first << "," << range.second << "}";
    separator = ",";
  }
  ss << "]";
  return ss.str();
}
int64_t GetItemNumByShape(const std::vector<int64_t> &shape) { | int64_t GetItemNumByShape(const std::vector<int64_t> &shape) { | ||||
int64_t num = 1; | int64_t num = 1; | ||||
for (auto dim : shape) { | for (auto dim : shape) { | ||||
@@ -54,6 +54,8 @@ std::string ShapeToString(const GeShape &shape); | |||||
std::string ShapeToString(const std::vector<int64_t> &shape); | std::string ShapeToString(const std::vector<int64_t> &shape); | ||||
std::string RangeToString(const std::vector<std::pair<int64_t, int64_t>> &ranges); | |||||
int64_t GetItemNumByShape(const std::vector<int64_t> &shape); | int64_t GetItemNumByShape(const std::vector<int64_t> &shape); | ||||
bool CheckShapeValid(const std::vector<int64_t> &shape, const int64_t expect_dims); | bool CheckShapeValid(const std::vector<int64_t> &shape, const int64_t expect_dims); | ||||
@@ -20,17 +20,23 @@ | |||||
#include "external/graph/operator_factory.h" | #include "external/graph/operator_factory.h" | ||||
#include "graph/utils/node_utils.h" | #include "graph/utils/node_utils.h" | ||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
#include "ge_local_engine/engine/host_cpu_engine.h" | |||||
#include "init/gelib.h" | #include "init/gelib.h" | ||||
namespace ge { | namespace ge { | ||||
const int64_t kStartCallNum = 1; | const int64_t kStartCallNum = 1; | ||||
const std::string kKernelLibName = "aicpu_tf_kernel"; | const std::string kKernelLibName = "aicpu_tf_kernel"; | ||||
// tf_kernel.json opsFlag config | |||||
const std::string kOpsFlagClose = "0"; | const std::string kOpsFlagClose = "0"; | ||||
Status RunOpKernelWithCheck(NodePtr &node, | |||||
const vector<ConstGeTensorPtr> &inputs, | |||||
std::vector<GeTensorPtr> &outputs) { | |||||
const map<string, pair<uint64_t, uint64_t>> &ConstantFoldingPass::GetGeConstantFoldingPerfStatistic() const { | |||||
return statistic_of_ge_constant_folding_; | |||||
} | |||||
const map<string, pair<uint64_t, uint64_t>> &ConstantFoldingPass::GetOpConstantFoldingPerfStatistic() const { | |||||
return statistic_of_op_constant_folding_; | |||||
} | |||||
Status ConstantFoldingPass::RunOpKernelWithCheck(NodePtr &node, const vector<ConstGeTensorPtr> &inputs, | |||||
std::vector<GeTensorPtr> &outputs) { | |||||
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | ||||
if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) { | if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) { | ||||
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Check][Param] GE is not initialized or is finalized."); | GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Check][Param] GE is not initialized or is finalized."); | ||||
@@ -47,15 +53,13 @@ Status RunOpKernelWithCheck(NodePtr &node, | |||||
if (ops_flag == kOpsFlagClose) { | if (ops_flag == kOpsFlagClose) { | ||||
return UNSUPPORTED; | return UNSUPPORTED; | ||||
} | } | ||||
return FoldingPass::RunOpKernel(node, inputs, outputs); | |||||
return RunOpKernel(node, inputs, outputs); | |||||
} | } | ||||
const map<string, pair<uint64_t, uint64_t>> &ConstantFoldingPass::GetGeConstantFoldingPerfStatistic() const { | |||||
return statistic_of_ge_constant_folding_; | |||||
} | |||||
const map<string, pair<uint64_t, uint64_t>> &ConstantFoldingPass::GetOpConstantFoldingPerfStatistic() const { | |||||
return statistic_of_op_constant_folding_; | |||||
Status ConstantFoldingPass::RunOpKernel(NodePtr &node, | |||||
const vector<ConstGeTensorPtr> &inputs, | |||||
std::vector<GeTensorPtr> &outputs) { | |||||
return HostCpuEngine::GetInstance().Run(node, inputs, outputs); | |||||
} | } | ||||
Status ConstantFoldingPass::Run(ge::NodePtr &node) { | Status ConstantFoldingPass::Run(ge::NodePtr &node) { | ||||
@@ -28,6 +28,11 @@ class ConstantFoldingPass : public FoldingPass { | |||||
Status Run(ge::NodePtr &node) override; | Status Run(ge::NodePtr &node) override; | ||||
const std::map<std::string, std::pair<std::uint64_t, uint64_t>> &GetGeConstantFoldingPerfStatistic() const; | const std::map<std::string, std::pair<std::uint64_t, uint64_t>> &GetGeConstantFoldingPerfStatistic() const; | ||||
const std::map<std::string, std::pair<std::uint64_t, uint64_t>> &GetOpConstantFoldingPerfStatistic() const; | const std::map<std::string, std::pair<std::uint64_t, uint64_t>> &GetOpConstantFoldingPerfStatistic() const; | ||||
static Status RunOpKernel(NodePtr &node, const vector<ConstGeTensorPtr> &inputs, vector<GeTensorPtr> &outputs); | |||||
static Status RunOpKernelWithCheck(NodePtr &node, const vector<ConstGeTensorPtr> &inputs, | |||||
std::vector<GeTensorPtr> &outputs); | |||||
private: | private: | ||||
std::map<std::string, std::pair<std::uint64_t, uint64_t>> statistic_of_op_constant_folding_; | std::map<std::string, std::pair<std::uint64_t, uint64_t>> statistic_of_op_constant_folding_; | ||||
std::map<std::string, std::pair<std::uint64_t, uint64_t>> statistic_of_ge_constant_folding_; | std::map<std::string, std::pair<std::uint64_t, uint64_t>> statistic_of_ge_constant_folding_; | ||||
@@ -28,8 +28,6 @@ | |||||
#include "inc/kernel.h" | #include "inc/kernel.h" | ||||
#include "inc/kernel_factory.h" | #include "inc/kernel_factory.h" | ||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "ge_local_engine/engine/host_cpu_engine.h" | |||||
namespace ge { | namespace ge { | ||||
namespace folding_pass { | namespace folding_pass { | ||||
@@ -123,12 +121,6 @@ NodePtr AddIdentityNodeToGraph(const std::string &name, const GeTensorDesc &tens | |||||
} | } | ||||
} // namespace | } // namespace | ||||
Status FoldingPass::RunOpKernel(NodePtr &node, | |||||
const vector<ConstGeTensorPtr> &inputs, | |||||
std::vector<GeTensorPtr> &outputs) { | |||||
return HostCpuEngine::GetInstance().Run(node, inputs, outputs); | |||||
} | |||||
Status FoldingPass::Folding(NodePtr &node, vector<GeTensorPtr> &outputs) { | Status FoldingPass::Folding(NodePtr &node, vector<GeTensorPtr> &outputs) { | ||||
GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
GELOGD("begin folding node:%s", node->GetName().c_str()); | GELOGD("begin folding node:%s", node->GetName().c_str()); | ||||
@@ -34,8 +34,6 @@ bool IsNoNeedConstantFolding(const NodePtr &node); | |||||
using IndexsToAnchors = std::map<int, std::vector<InDataAnchorPtr>>; | using IndexsToAnchors = std::map<int, std::vector<InDataAnchorPtr>>; | ||||
class FoldingPass : public BaseNodePass { | class FoldingPass : public BaseNodePass { | ||||
public: | |||||
static Status RunOpKernel(NodePtr &node, const vector<ConstGeTensorPtr> &inputs, vector<GeTensorPtr> &outputs); | |||||
protected: | protected: | ||||
Status Folding(NodePtr &node, vector<GeTensorPtr> &outputs); | Status Folding(NodePtr &node, vector<GeTensorPtr> &outputs); | ||||
private: | private: | ||||
@@ -0,0 +1,386 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "infer_base_pass.h" | |||||
#include "common/ge/ge_util.h" | |||||
#include "common/util/error_manager/error_manager.h" | |||||
#include "framework/common/debug/ge_log.h" | |||||
#include "framework/common/util.h" | |||||
#include "graph/debug/ge_attr_define.h" | |||||
#include "graph/utils/graph_utils.h" | |||||
#include "graph/utils/node_utils.h" | |||||
#include "graph/utils/tensor_utils.h" | |||||
#include "graph/utils/type_utils.h" | |||||
namespace ge { | |||||
namespace { | |||||
// Find the NetOutput node of `sub_graph`, scanning direct nodes from back to
// front (NetOutput is normally appended last). Returns GRAPH_FAILED when no
// valid NetOutput (with an OpDesc) exists.
// Fix: the original dereferenced sub_node (GetType()) BEFORE its null check,
// making the null-error branch unreachable and crashing on a null entry.
graphStatus FindValidSubgraphNetoutput(const ConstNodePtr &node, const ComputeGraphPtr &sub_graph, NodePtr &netoutput) {
  auto sub_nodes = sub_graph->GetDirectNode();
  for (size_t i = sub_nodes.size(); i > 0; --i) {
    auto sub_node = sub_nodes.at(i - 1);
    // Null check must precede any dereference of sub_node.
    if (sub_node == nullptr) {
      REPORT_INNER_ERROR("E19999", "NetOutput node is null in subgraph %s, parent node %s.",
                         sub_graph->GetName().c_str(), node->GetName().c_str());
      GELOGE(GRAPH_FAILED, "[Check][Param] NetOutput node is null on sub graph %s, parent node %s",
             sub_graph->GetName().c_str(), node->GetName().c_str());
      return GRAPH_FAILED;
    }
    if (sub_node->GetType() == NETOUTPUT) {
      auto sub_node_opdesc = sub_node->GetOpDesc();
      if (sub_node_opdesc == nullptr) {
        REPORT_INNER_ERROR("E19999", "Invalid NetOutput node in subgraph %s, parent node %s, no OpDesc on it",
                           sub_graph->GetName().c_str(), node->GetName().c_str());
        GELOGE(GRAPH_FAILED, "[Check][Param] Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it",
               sub_graph->GetName().c_str(), node->GetName().c_str());
        return GRAPH_FAILED;
      }
      netoutput = sub_node;
      return GRAPH_SUCCESS;
    }
  }
  REPORT_INNER_ERROR("E19999", "Can not find the NetOutput node in subgraph %s, parent node %s",
                     sub_graph->GetName().c_str(), node->GetName().c_str());
  GELOGE(GRAPH_FAILED, "[Check][Param] Can not find the NetOutput node in subgraph %s, parent node %s",
         sub_graph->GetName().c_str(), node->GetName().c_str());
  return GRAPH_FAILED;
}
} // namespace | |||||
// Pass entry point: validate the node, run the subclass inference, and queue
// any downstream nodes whose tensor desc changed for an immediate re-pass.
Status InferBasePass::Run(NodePtr &node) {
  GE_CHECK_NOTNULL(node);
  GE_CHECK_NOTNULL(node->GetOpDesc());

  if (!NeedInfer(node)) {
    GELOGD("Node %s does not need to infer.", node->GetName().c_str());
    return SUCCESS;
  }

  std::set<NodePtr> changed_nodes;
  const bool before_subgraph = !OptionExists(kOptimizeAfterSubGraph);
  const auto status = InferAndUpdate(node, before_subgraph, changed_nodes);
  if (status != GRAPH_SUCCESS) {
    GELOGE(status, "Infer and update for node %s failed! ret: %u", node->GetName().c_str(), status);
    return GRAPH_FAILED;
  }

  AddChangedNodesImmediateRepass(changed_nodes);
  return SUCCESS;
}
// Default policy: every node is inferred; subclasses override this to filter.
bool InferBasePass::NeedInfer(const NodePtr &node) const {
  (void)node;  // unused in the base implementation
  return true;
}
void InferBasePass::AddChangedNodesImmediateRepass(const std::set<NodePtr> &changed_nodes) { | |||||
for (const auto &node_ele : changed_nodes) { | |||||
AddImmediateRePassNode(node_ele); | |||||
} | |||||
} | |||||
// Run inference on `node` and propagate the result.
// For nodes that own subgraphs, tensor descs are synchronized either parent->
// subgraph Data (before subgraph optimization) or subgraph NetOutput->parent
// (after). Nodes whose input desc changed are collected into `changed_nodes`.
// Fix: the original fell through to the "succeeded" debug log even when
// UpdateTensorDescToPeerInputs failed; now it returns in the error branch.
graphStatus InferBasePass::InferAndUpdate(NodePtr &node, bool before_subgraph, std::set<NodePtr> &changed_nodes) {
  graphStatus ret;
  if (ContainsSubgraph(node)) {
    if (before_subgraph) {
      ret = UpdateTensorDescToSubgraphData(node);
    } else {
      ret = UpdateTensorDescToParentNodeOutput(node);
    }
    if (ret != GRAPH_SUCCESS) {
      GELOGE(ret, "Update tensor desc failed between parent node %s and subgraphs. ret: %u", node->GetName().c_str(),
             ret);
      return ret;
    }
  }

  PrintInOutTensors(node, "before_infer");
  ret = Infer(node);
  PrintInOutTensors(node, "after_infer");
  if (ret == GRAPH_NODE_NEED_REPASS) {
    // if a node need re_pass, it is not necessary to update peer node input.
    changed_nodes.insert(node);
    return GRAPH_SUCCESS;
  } else if (ret != GRAPH_SUCCESS && ret != GRAPH_NOT_CHANGED) {
    GELOGE(ret, "Infer failed for node %s, ret: %u", node->GetName().c_str(), ret);
    return ret;
  }

  ret = UpdateTensorDescToPeerInputs(node, changed_nodes);
  if (ret != GRAPH_SUCCESS) {
    GELOGE(ret, "Node %s updates tensor desc to peer input nodes failed! ret: %u", node->GetName().c_str(), ret);
    return ret;  // do not log success below on failure
  }
  GELOGD("Node %s infer and update succeeded .", node->GetName().c_str());
  return ret;
}
bool InferBasePass::ContainsSubgraph(const NodePtr &node) { | |||||
auto sub_graph_names = node->GetOpDesc()->GetSubgraphInstanceNames(); | |||||
return !sub_graph_names.empty(); | |||||
} | |||||
// Push each of the node's (possibly refreshed) output descs to every peer
// input desc. Peers whose desc actually changed are collected into
// `changed_nodes` so the caller can schedule them for re-pass.
graphStatus InferBasePass::UpdateTensorDescToPeerInputs(NodePtr &node, std::set<NodePtr> &changed_nodes) {
  auto op_desc = node->GetOpDesc();
  for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
    auto src_output_desc = op_desc->MutableOutputDesc(out_anchor->GetIdx());
    for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) {
      auto peer_node = peer_in_anchor->GetOwnerNode();
      auto peer_op_desc = peer_node->GetOpDesc();
      if (peer_op_desc == nullptr) {
        continue;  // peer without OpDesc: nothing to update
      }
      auto peer_input_desc = peer_op_desc->MutableInputDesc(peer_in_anchor->GetIdx());
      if (peer_input_desc == nullptr) {
        continue;  // peer has no matching input desc
      }
      bool desc_changed = false;
      auto status = UpdateTensorDesc(src_output_desc, peer_input_desc, desc_changed);
      if (status != GRAPH_SUCCESS) {
        REPORT_CALL_ERROR("E19999", "Update peer input desc failed, node %s.", node->GetName().c_str());
        GELOGE(status, "Update peer input desc failed, node %s.", node->GetName().c_str());
        return status;
      }
      if (desc_changed) {
        changed_nodes.insert(peer_node);
        GELOGD("Node %s update peer node succeeded, peer node %s is changed.", node->GetName().c_str(),
               peer_node->GetName().c_str());
      }
    }
  }
  return GRAPH_SUCCESS;
}
std::vector<ComputeGraphPtr> InferBasePass::GetCurNodeSubgraphs(const NodePtr &node) { | |||||
std::vector<ComputeGraphPtr> cur_node_subgraph; | |||||
auto op_desc = node->GetOpDesc(); | |||||
auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); | |||||
if (sub_graph_names.empty()) { | |||||
return cur_node_subgraph; | |||||
} | |||||
auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); | |||||
for (const auto &name : sub_graph_names) { | |||||
if (name.empty()) { | |||||
GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str()); | |||||
continue; | |||||
} | |||||
auto sub_graph = root_graph->GetSubgraph(name); | |||||
if (sub_graph == nullptr) { | |||||
GELOGW("The subgrpah %s for node %s is null.", name.c_str(), node->GetName().c_str()); | |||||
continue; | |||||
} | |||||
cur_node_subgraph.emplace_back(sub_graph); | |||||
} | |||||
return cur_node_subgraph; | |||||
} | |||||
// Before subgraph optimization: copy the parent node's input tensor descs onto
// the matching Data nodes inside every subgraph. Each subgraph Data node
// carries ATTR_NAME_PARENT_NODE_INDEX (ref_i) naming which parent input it
// mirrors; both the Data node's input(0) and output(0) descs are refreshed.
// Returns GRAPH_FAILED on malformed Data nodes or an incompatible ref index.
graphStatus InferBasePass::UpdateTensorDescToSubgraphData(NodePtr &node) {
  auto op_desc = node->GetOpDesc();
  for (const auto &sub_graph : GetCurNodeSubgraphs(node)) {
    for (const auto &node_sub : sub_graph->GetDirectNode()) {
      // Only Data nodes mirror a parent input; everything else is skipped.
      if (node_sub->GetType() != DATA) {
        continue;
      }
      auto data_opdesc = node_sub->GetOpDesc();
      if (data_opdesc == nullptr) {
        REPORT_INNER_ERROR("E19999", "Invalid data node on the sub graph %s parent node %s, no OpDesc",
                           sub_graph->GetName().c_str(), node->GetName().c_str());
        GELOGE(GRAPH_FAILED, "[Get][OpDesc] Invalid data node on the sub graph %s parent node %s, no OpDesc",
               sub_graph->GetName().c_str(), node->GetName().c_str());
        return GRAPH_FAILED;
      }
      // ref_i: index of the parent-node input this Data node corresponds to.
      int ref_i;
      if (!AttrUtils::GetInt(data_opdesc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
        REPORT_INNER_ERROR("E19999", "Invalid data node on the sub graph %s parent node %s, no ref-index attribute",
                           sub_graph->GetName().c_str(), node->GetName().c_str());
        GELOGE(GRAPH_FAILED, "[Get][Int] Invalid data node on the sub graph %s parent node %s, no ref-index attribute",
               sub_graph->GetName().c_str(), node->GetName().c_str());
        return GRAPH_FAILED;
      }
      GELOGD("Subgraph Data node ref_index is %d, parent node is %s.", ref_i, node->GetName().c_str());

      // In multi-batch, data shape of subgraph is different, no need to refresh.
      if (data_opdesc->HasAttr(ATTR_MBATCH_ORIGIN_INPUT_DIMS)) {
        GELOGD("While updating subgraph data node, ignore node %s which is created by multi-dims",
               data_opdesc->GetName().c_str());
        continue;
      }

      auto input_desc = op_desc->MutableInputDesc(ref_i);
      if (input_desc == nullptr) {
        // ref_i points past the parent's input list: attribute and topology disagree.
        REPORT_INNER_ERROR("E19999",
                           "The ref index(%d) on the data %s on the sub graph %s "
                           "parent node %s are incompatible, inputs num %u",
                           ref_i, node_sub->GetName().c_str(), sub_graph->GetName().c_str(), node->GetName().c_str(),
                           node->GetAllInDataAnchorsSize());
        GELOGE(GRAPH_FAILED,
               "[Call][MutableInputDesc] The ref index(%d) on the data %s on the sub graph %s "
               "parent node %s are incompatible, inputs num %u",
               ref_i, node_sub->GetName().c_str(), sub_graph->GetName().c_str(), node->GetName().c_str(),
               node->GetAllInDataAnchorsSize());
        return GRAPH_FAILED;
      }
      GELOGI("Ref index is %d, input_desc dtype is %d, node name is %s", ref_i, input_desc->GetDataType(),
             node->GetName().c_str());

      // Refresh the Data node's input(0) and output(0) from the parent input desc.
      bool has_tensor_desc_changed = false;
      auto data_input_td = data_opdesc->MutableInputDesc(0);
      auto ret = UpdateTensorDesc(input_desc, data_input_td, has_tensor_desc_changed);
      if (ret != GRAPH_SUCCESS) {
        REPORT_CALL_ERROR("E19999", "Failed to update input desc of data %s on the sub graph %s parent node %s",
                          node_sub->GetName().c_str(), sub_graph->GetName().c_str(), node->GetName().c_str());
        GELOGE(GRAPH_FAILED, "[Update][InputDesc] of data %s on the sub graph %s parent node %s failed",
               node_sub->GetName().c_str(), sub_graph->GetName().c_str(), node->GetName().c_str());
        return ret;
      }
      auto data_output_td = data_opdesc->MutableOutputDesc(0);
      ret = UpdateTensorDesc(input_desc, data_output_td, has_tensor_desc_changed);
      if (ret != GRAPH_SUCCESS) {
        REPORT_CALL_ERROR("E19999", "Failed to update output desc of data %s on the sub graph %s parent node %s",
                          node_sub->GetName().c_str(), sub_graph->GetName().c_str(), node->GetName().c_str());
        GELOGE(GRAPH_FAILED, "[Update][OutputDesc] of data %s on the sub graph %s parent node %s failed",
               node_sub->GetName().c_str(), sub_graph->GetName().c_str(), node->GetName().c_str());
        return ret;
      }
      GELOGD("Parent node %s update subgraph data %s input and output succeed.", node->GetName().c_str(),
             data_opdesc->GetName().c_str());
    }
  }
  return GRAPH_SUCCESS;
}
// After subgraph optimization: gather, for each parent-node output index, the
// NetOutput input descs from every subgraph that refer to it (via
// ATTR_NAME_PARENT_NODE_INDEX on the tensor desc), then merge them into the
// parent node's outputs through UpdateParentNodeContainsSubgraphs.
graphStatus InferBasePass::UpdateTensorDescToParentNodeOutput(NodePtr &node) {
  // ref_out_tensors[i] collects the candidate descs for parent output i,
  // one entry per subgraph branch (e.g. then/else of an If node).
  std::vector<std::vector<GeTensorDescPtr>> ref_out_tensors(node->GetAllOutDataAnchorsSize());

  for (const auto &sub_graph : GetCurNodeSubgraphs(node)) {
    NodePtr netoutput;
    auto ret = FindValidSubgraphNetoutput(node, sub_graph, netoutput);
    if (ret != GRAPH_SUCCESS) {
      return ret;
    }

    auto netoutput_opdesc = netoutput->GetOpDesc();
    for (auto &netoutput_in_anchor : netoutput->GetAllInDataAnchors()) {
      auto netoutput_in_desc = netoutput_opdesc->MutableInputDesc(netoutput_in_anchor->GetIdx());
      if (netoutput_in_desc == nullptr) {
        REPORT_INNER_ERROR("E19999",
                           "Invalid NetOutput node on sub graph %s, parent node %s, can not find input tensor %d",
                           sub_graph->GetName().c_str(), node->GetName().c_str(), netoutput_in_anchor->GetIdx());
        GELOGE(GRAPH_FAILED,
               "[Get][Tensor] Invalid NetOutput node on sub graph %s, parent node %s, can not find input tensor %d",
               sub_graph->GetName().c_str(), node->GetName().c_str(), netoutput_in_anchor->GetIdx());
        return GRAPH_FAILED;
      }
      GELOGI("Netoutput in anchor index is %d, input tensor dim is %zu", netoutput_in_anchor->GetIdx(),
             netoutput_in_desc->GetShape().GetDimNum());
      int ref_i;
      if (!AttrUtils::GetInt(netoutput_in_desc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
        // if there is no ref index on the TensorDesc, it means the output data will be ignored outer.
        continue;
      }
      GELOGI("Parent node index of edge desc is %d", ref_i);
      // ref_i must name an existing parent output anchor.
      if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllOutDataAnchorsSize()) {
        REPORT_INNER_ERROR("E19999",
                           "Invalid ref_index %d of parent node %s, ref_index should less than %u.", ref_i,
                           node->GetName().c_str(), node->GetAllOutDataAnchorsSize());
        GELOGE(GRAPH_FAILED,
               "[Get][Ref_index] Invalid ref_index %d of parent node %s, ref_index should less than %u.", ref_i,
               node->GetName().c_str(), node->GetAllOutDataAnchorsSize());
        return GRAPH_FAILED;
      }
      ref_out_tensors[ref_i].emplace_back(netoutput_in_desc);
    }
  }

  return UpdateParentNodeContainsSubgraphs(node, ref_out_tensors);
}
// Merge, per output index, the tensor descs collected from all subgraph
// NetOutputs into the parent node's output descs. Multi-dims/batch nodes
// (ATTR_NAME_BATCH_NUM) use the dedicated merge hook. Every output index must
// have at least one candidate desc, otherwise the graph is malformed.
// Improvement: node->GetOpDesc() hoisted out of the loop (loop-invariant).
graphStatus InferBasePass::UpdateParentNodeContainsSubgraphs(
    NodePtr &node, const std::vector<std::vector<GeTensorDescPtr>> &ref_out_tensors) {
  auto node_op_desc = node->GetOpDesc();
  for (size_t i = 0; i < ref_out_tensors.size(); i++) {
    if (ref_out_tensors[i].empty()) {
      REPORT_CALL_ERROR("E19999", "Parent node %s ref_index %zu subgraph output tensor list is empty.",
                        node->GetName().c_str(), i);
      GELOGE(GRAPH_FAILED, "[Param][check] Parent node %s ref_index %zu subgraph output tensor list is empty.",
             node->GetName().c_str(), i);
      return GRAPH_FAILED;
    }
    auto node_output_td = node_op_desc->MutableOutputDesc(i);
    if (node_output_td == nullptr) {
      REPORT_CALL_ERROR("E19999", "Node %s output %zu tensor desc is null.", node->GetName().c_str(), i);
      GELOGE(GRAPH_FAILED, "[Param][check] Node %s output %zu tensor desc is null.", node->GetName().c_str(), i);
      return GRAPH_FAILED;
    }

    graphStatus ret;
    if (node_op_desc->HasAttr(ATTR_NAME_BATCH_NUM)) {
      // Dynamic multi-dims/batch: subgraph outputs may differ per batch branch.
      ret = UpdateOutputFromSubgraphsForMultiDims(ref_out_tensors[i], node_output_td);
    } else {
      ret = UpdateOutputFromSubgraphs(ref_out_tensors[i], node_output_td);
    }
    if (ret != GRAPH_SUCCESS) {
      REPORT_CALL_ERROR("E19999", "Node %s update output %zu tensor desc failed. ret: %u", node->GetName().c_str(), i,
                        ret);
      GELOGE(GRAPH_FAILED, "[Param][check] Node %s update output %zu tensor desc failed. ret: %u",
             node->GetName().c_str(), i, ret);
      return ret;
    }
    GELOGD("Parent node %s successfully updated the output tensors from subgraphs.", node->GetName().c_str());
  }
  return GRAPH_SUCCESS;
}
void InferBasePass::PrintInOutTensors(const NodePtr &node, const std::string &phase) { | |||||
if (!IsLogEnable(GE, DLOG_DEBUG)) { | |||||
return; | |||||
} | |||||
if (node == nullptr) { | |||||
REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid"); | |||||
GELOGE(GRAPH_FAILED, "[Check][Param] node is null"); | |||||
return; | |||||
} | |||||
ge::OpDescPtr op_desc = node->GetOpDesc(); | |||||
GE_IF_BOOL_EXEC(op_desc == nullptr, REPORT_INNER_ERROR("E19999", "Node has no opdesc, check invalid"); | |||||
GELOGE(GRAPH_FAILED, "[Get][OpDesc] op_desc is null."); return ); | |||||
std::stringstream ss; | |||||
ss << "{"; | |||||
int32_t in_idx = 0; | |||||
for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { | |||||
if (input_desc == nullptr) { | |||||
in_idx++; | |||||
continue; | |||||
} | |||||
if (in_idx > 0) { | |||||
ss << " "; | |||||
} | |||||
ss << "input_" << in_idx << " tensor: "; | |||||
ss << SerialTensorInfo(input_desc); | |||||
in_idx++; | |||||
} | |||||
int32_t out_idx = 0; | |||||
for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) { | |||||
if (output_desc == nullptr) { | |||||
out_idx++; | |||||
continue; | |||||
} | |||||
ss << " "; | |||||
ss << "output_" << out_idx << " tensor: "; | |||||
ss << SerialTensorInfo(output_desc); | |||||
out_idx++; | |||||
} | |||||
ss << "}"; | |||||
GELOGD("Infer tensor dump [%s], Node name: [%s]. %s", phase.c_str(), node->GetName().c_str(), ss.str().c_str()); | |||||
} | |||||
} // namespace ge |
@@ -0,0 +1,65 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_GRAPH_PASSES_INFER_BASE_PASS_H_ | |||||
#define GE_GRAPH_PASSES_INFER_BASE_PASS_H_ | |||||
#include "graph/passes/base_pass.h" | |||||
namespace ge { | |||||
/**
 * Base class of shape/value-range inference passes. Drives the generic flow
 * (run subclass inference, sync tensor descs with subgraphs, propagate to peer
 * inputs, collect changed nodes for re-pass); subclasses supply the actual
 * inference and desc-merge logic through the pure-virtual hooks.
 */
class InferBasePass : public BaseNodePass {
 public:
  Status Run(NodePtr &node) override;
  // Infer node and propagate results; `before_subgraph` selects the direction
  // of parent<->subgraph desc sync; changed peers are returned in changed_nodes.
  graphStatus InferAndUpdate(NodePtr &node, bool before_subgraph, std::set<NodePtr> &changed_nodes);
  // Debug dump of all in/out tensor descs, tagged with `phase`.
  void PrintInOutTensors(const NodePtr &node, const std::string &phase);

 protected:
  // One-line serialization of a tensor desc for the debug dump.
  virtual std::string SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const = 0;
  // Filter hook; the base implementation infers every node.
  virtual bool NeedInfer(const NodePtr &node) const;
  // Subclass-specific inference on the node itself.
  virtual graphStatus Infer(NodePtr &node) = 0;

  /**
   * Update the output TensorDesc by src TensorDesc. This will be called when updating peer node input desc.
   * @param src, input TensorDesc
   * @param dst, output TensorDesc to be updated
   * @param changed, set to true when dst was actually modified
   * @return status of the update
   */
  virtual graphStatus UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) = 0;

  /**
   * Update the output TensorDesc for nodes which contain subgraphs.
   * In dynamic multi-dims/batch/images size scene, the update process maybe different,
   * in which case, the `InferBasePass` will call method `UpdateOutputFromSubgraphsForMultiDims` instead.
   * @param src, input TensorDesc from NetOutput nodes in all subgraphs
   * @param dst, output TensorDesc to be updated
   * @return status of the merge
   */
  virtual graphStatus UpdateOutputFromSubgraphs(const std::vector<GeTensorDescPtr> &src,
                                                GeTensorDescPtr &dst) = 0;
  virtual graphStatus UpdateOutputFromSubgraphsForMultiDims(const std::vector<GeTensorDescPtr> &src,
                                                            GeTensorDescPtr &dst) = 0;

 private:
  // Queue changed nodes for an immediate re-pass.
  void AddChangedNodesImmediateRepass(const std::set<NodePtr> &changed_nodes);
  // True when the node's OpDesc registers at least one subgraph instance name.
  bool ContainsSubgraph(const NodePtr &node);
  // Resolve registered subgraph names to ComputeGraph instances.
  std::vector<ComputeGraphPtr> GetCurNodeSubgraphs(const NodePtr &node);
  // Parent inputs -> subgraph Data nodes (before subgraph optimization).
  graphStatus UpdateTensorDescToSubgraphData(NodePtr &node);
  // Subgraph NetOutputs -> parent outputs (after subgraph optimization).
  graphStatus UpdateTensorDescToParentNodeOutput(NodePtr &node);
  graphStatus UpdateParentNodeContainsSubgraphs(NodePtr &node,
                                                const std::vector<std::vector<GeTensorDescPtr>> &ref_out_tensors);
  // Propagate refreshed output descs to all peer input descs.
  graphStatus UpdateTensorDescToPeerInputs(NodePtr &node, std::set<NodePtr> &changed_nodes);
};
} // namespace ge | |||||
#endif // GE_GRAPH_PASSES_INFER_BASE_PASS_H_ |
@@ -0,0 +1,500 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "graph/passes/infer_value_range_pass.h" | |||||
#include "common/formats/utils/formats_trans_utils.h" | |||||
#include "common/util/error_manager/error_manager.h" | |||||
#include "framework/common/debug/ge_log.h" | |||||
#include "graph/debug/ge_attr_define.h" | |||||
#include "graph/operator_factory_impl.h" | |||||
#include "graph/passes/constant_folding_pass.h" | |||||
#include "graph/utils/type_utils.h" | |||||
#include "common/ge/ge_util.h" | |||||
using std::unique_ptr; | |||||
namespace ge { | |||||
namespace { | |||||
#define GET_DATA_BY_DTYPE(DTYPE, TYPE) \ | |||||
case (DTYPE): \ | |||||
ConstructValueRange<TYPE>(lower_boundary_tensor, upper_boundary_tensor, output_tensor_value_range); \ | |||||
break; | |||||
// Append "<shape_range>,<origin_shape_range>" to desc_str, each range list
// serialized via formats::RangeToString. Getter failures are deliberately
// ignored; an unset range serializes as "[]".
// Improvement: dropped the dead trailing clear() on a local about to be destroyed.
void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) {
  std::vector<std::pair<int64_t, int64_t>> shape_range;
  (void)desc->GetShapeRange(shape_range);
  desc_str += formats::RangeToString(shape_range);
  shape_range.clear();  // reuse the buffer for the origin shape range
  (void)desc->GetOriginShapeRange(shape_range);
  desc_str += ",";
  desc_str += formats::RangeToString(shape_range);
}
// Execute the node on host CPU to compute its value range: try the constant-
// folding host kernel first, then fall back to the registered op kernel.
// Returns NOT_CHANGED (not an error) when no kernel can handle the node.
Status RunCpuKernelForValueRange(NodePtr &node, const vector<ConstGeTensorPtr> &inputs,
                                 std::vector<GeTensorPtr> &outputs) {
  if (ConstantFoldingPass::RunOpKernel(node, inputs, outputs) != SUCCESS) {
    auto fallback_kernel = folding_pass::GetKernelByType(node);
    if (fallback_kernel == nullptr) {
      GELOGW("Calculate value range failed, no op kernel for node %s type %s", node->GetName().c_str(),
             node->GetType().c_str());
      return NOT_CHANGED;
    }
    if (fallback_kernel->Compute(node->GetOpDesc(), inputs, outputs) != SUCCESS) {
      GELOGW("Calculate value range failed, node %s run cpu kernel failed.", node->GetName().c_str());
      return NOT_CHANGED;
    }
  }
  GELOGI("Node %s type %s, run cpu kernel success.", node->GetName().c_str(), node->GetType().c_str());
  return SUCCESS;
}
} // namespace | |||||
// Infers the value range of `node`'s outputs.
// Strategy: use the operator-registered infer function when one is registered
// without a cpu kernel; otherwise run the node's cpu kernel on the range
// boundaries. Inputs with an unknown ({x:-1}) range short-circuit to the
// worst-case range, since the kernel cannot compute with -1 bounds.
graphStatus InferValueRangePass::Infer(NodePtr &node) {
  auto infer_value_range_param = OperatorFactoryImpl::GetInferValueRangePara(node->GetType());
  if (!infer_value_range_param.use_cpu_kernel) {
    // Use registered func to calculate value range
    if (infer_value_range_param.infer_value_func == nullptr) {
      GELOGW("The registered func of node %s to infer value range is nullptr.", node->GetName().c_str());
      return GRAPH_NOT_CHANGED;
    }
    Operator op = OpDescUtils::CreateOperatorFromNode(node);
    const auto call_ret = node->GetOpDesc()->CallInferValueRangeFunc(op);
    if (call_ret != GRAPH_SUCCESS) {
      GELOGW("Node %s call infer value range func failed, ret: %u.", node->GetName().c_str(), call_ret);
      return GRAPH_NOT_CHANGED;
    }
    GELOGD("Node %s infer value range func succeed by registered func.", node->GetName().c_str());
    return GRAPH_SUCCESS;
  }

  // if input value range has -1, cpu kernel cannot calculate correctly, so set {1:-1}
  if (InputHasUnknownValueRange(node)) {
    GELOGI("Node %s has unknown value range in input tensors, set value range {1:-1}, and skip cpu kernel.",
           node->GetName().c_str());
    return GenerateWorstValueRange(node);
  }

  // Use CPU kernel func to calculate value range
  const auto kernel_ret = ConstructInputAndInferValueRange(node);
  if (kernel_ret != GRAPH_SUCCESS) {
    GELOGW("Use CPU kernel to calculate value range failed. node: %s, ret: %u", node->GetName().c_str(), kernel_ret);
    return GRAPH_NOT_CHANGED;
  }
  GELOGD("Node %s infer value range func succeed by running cpu kernel.", node->GetName().c_str());
  return GRAPH_SUCCESS;
}
std::string InferValueRangePass::SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const { | |||||
std::stringstream ss; | |||||
ss << "["; | |||||
ss << "(shape:[" << tensor_desc->MutableShape().ToString() << "]),"; | |||||
string range_str; | |||||
SerialShapeRange(tensor_desc, range_str); | |||||
ss << "(shape_range:" << range_str << "),"; | |||||
std::vector<std::pair<int64_t, int64_t>> value_range; | |||||
(void)tensor_desc->GetValueRange(value_range); | |||||
string value_range_str = formats::RangeToString(value_range); | |||||
ss << "(value_range:" << value_range_str << ")]"; | |||||
return ss.str(); | |||||
} | |||||
bool InferValueRangePass::NeedInfer(const NodePtr &node) const { | |||||
auto infer_value_range_param = OperatorFactoryImpl::GetInferValueRangePara(node->GetType()); | |||||
if (!infer_value_range_param.is_initialized) { | |||||
GELOGD("Node %s does not register func to infer value range, skip infer_value_range_pass.", | |||||
node->GetName().c_str()); | |||||
return false; | |||||
} | |||||
if (infer_value_range_param.when_call == INPUT_IS_DYNAMIC) { | |||||
// Only do infer for node that all inputs are dynamic, such as shape | |||||
if (InputIsDynamic(node)) { | |||||
return true; | |||||
} | |||||
GELOGD("Node %s register func to infer value range and when_call is INPUT_IS_DYNAMIC, but check input failed.", | |||||
node->GetName().c_str()); | |||||
} else if (infer_value_range_param.when_call == INPUT_HAS_VALUE_RANGE) { | |||||
// Only do infer for node that all inputs have value_range or node type of inputs is constant/const | |||||
if (InputIsConstOrHasValueRange(node)) { | |||||
return true; | |||||
} | |||||
GELOGD("Node %s register func to infer value range and when_call is INPUT_HAS_VALUE_RANGE, but check input failed.", | |||||
node->GetName().c_str()); | |||||
} | |||||
GELOGD("Node %s does not need to infer value range, skip infer_value_range_pass.", node->GetName().c_str()); | |||||
return false; | |||||
} | |||||
bool InferValueRangePass::InputIsDynamic(const NodePtr &node) const{ | |||||
bool input_is_dynamic = false; | |||||
auto cur_op_desc = node->GetOpDesc(); | |||||
for (const auto &input_desc : cur_op_desc->GetAllInputsDescPtr()) { | |||||
auto dims = input_desc->GetShape().GetDims(); | |||||
for (auto dim : dims) { | |||||
if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) { | |||||
input_is_dynamic = true; | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
return input_is_dynamic; | |||||
} | |||||
bool InferValueRangePass::InputIsConstOrHasValueRange(const NodePtr &node) const { | |||||
bool input_is_const_or_has_value_range = true; | |||||
auto cur_op_desc = node->GetOpDesc(); | |||||
auto in_data_anchors = node->GetAllInDataAnchors(); | |||||
for (size_t i = 0; i < in_data_anchors.size(); ++i) { | |||||
auto peer_out_anchor = in_data_anchors.at(i)->GetPeerOutAnchor(); | |||||
if (peer_out_anchor == nullptr) { | |||||
continue; | |||||
} | |||||
auto peer_node = peer_out_anchor->GetOwnerNode(); | |||||
if (peer_node == nullptr || peer_node->GetOpDesc() == nullptr) { | |||||
continue; | |||||
} | |||||
if ((peer_node->GetType() == CONSTANT) || (peer_node->GetType() == CONSTANTOP)) { | |||||
continue; | |||||
} | |||||
const auto &input_desc = cur_op_desc->GetInputDesc(i); | |||||
std::vector<std::pair<int64_t, int64_t>> value_range; | |||||
(void)input_desc.GetValueRange(value_range); | |||||
if (value_range.empty()) { | |||||
GELOGD("Node %s input %zu does not have value range, skip infer_value_range_pass for current node.", | |||||
node->GetName().c_str(), i); | |||||
input_is_const_or_has_value_range = false; | |||||
break; | |||||
} | |||||
} | |||||
return input_is_const_or_has_value_range; | |||||
} | |||||
bool InferValueRangePass::InputHasUnknownValueRange(const NodePtr &node) const { | |||||
bool has_unknown_value_range = false; | |||||
auto cur_op_desc = node->GetOpDesc(); | |||||
for (const auto &input_desc : cur_op_desc->GetAllInputsDescPtr()) { | |||||
std::vector<std::pair<int64_t, int64_t>> input_desc_value_range; | |||||
input_desc->GetValueRange(input_desc_value_range); | |||||
if (!input_desc_value_range.empty()) { | |||||
for (const auto &range : input_desc_value_range) { | |||||
if (range.first == -1 || range.second == -1) { | |||||
GELOGD("Node %s input tensors have unknown value range, value range is %s.", node->GetName().c_str(), | |||||
formats::RangeToString(input_desc_value_range).c_str()); | |||||
has_unknown_value_range = true; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
return has_unknown_value_range; | |||||
} | |||||
// Copies the value range from `src` to `dst`.
// Sets `changed` to true when the ranges differed before the copy, so the
// base pass knows whether downstream nodes need re-inference.
graphStatus InferValueRangePass::UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) {
  if ((src == nullptr) || (dst == nullptr)) {
    REPORT_CALL_ERROR("E19999", "While updating tensor desc, input desc is null.");
    GELOGE(GRAPH_FAILED, "[Param][check] While updating tensor desc, input desc is null.");
    return GRAPH_FAILED;
  }

  std::vector<std::pair<int64_t, int64_t>> src_value_range;
  (void)src->GetValueRange(src_value_range);
  std::vector<std::pair<int64_t, int64_t>> dst_value_range;
  (void)dst->GetValueRange(dst_value_range);

  changed = (src_value_range != dst_value_range);
  if (changed) {
    GELOGD("While updating tensor desc, value range has been changed, src value range: %s, dst value range: %s.",
           formats::RangeToString(src_value_range).c_str(), formats::RangeToString(dst_value_range).c_str());
  }
  dst->SetValueRange(src_value_range);
  return GRAPH_SUCCESS;
}
// Merges the value ranges of subgraph outputs (`src`) into the parent node's
// output desc (`dst`). Where branches disagree on a bound, that element
// degrades to the unknown range {1:-1}. If any branch has a different rank,
// the refresh is skipped entirely (GRAPH_SUCCESS, dst untouched).
// Fixes: guard against empty/null `src` and null `dst` — the original called
// src.at(0)-> unconditionally; also start merging from the second element,
// since the first iteration compared the baseline range against itself.
graphStatus InferValueRangePass::UpdateOutputFromSubgraphs(const std::vector<GeTensorDescPtr> &src,
                                                           GeTensorDescPtr &dst) {
  if (src.empty() || (src.front() == nullptr) || (dst == nullptr)) {
    REPORT_CALL_ERROR("E19999", "While updating output desc from subgraphs, src or dst desc is null.");
    GELOGE(GRAPH_FAILED, "[Param][check] While updating output desc from subgraphs, src or dst desc is null.");
    return GRAPH_FAILED;
  }
  // Use the first subgraph output as the baseline.
  std::vector<std::pair<int64_t, int64_t>> ref_out_tensor_value_range;
  (void)src.front()->GetValueRange(ref_out_tensor_value_range);
  for (size_t i = 1; i < src.size(); ++i) {
    std::vector<std::pair<int64_t, int64_t>> ref_tensor_value_range;
    (void)src.at(i)->GetValueRange(ref_tensor_value_range);
    if (ref_tensor_value_range.size() != ref_out_tensor_value_range.size()) {
      GELOGD("Update TensorDesc %s failed, rank of value ranges %s and %s are not the same, skip value range refresh.",
             dst->GetName().c_str(), formats::RangeToString(ref_out_tensor_value_range).c_str(),
             formats::RangeToString(ref_tensor_value_range).c_str());
      return GRAPH_SUCCESS;
    }
    for (size_t j = 0; j < ref_out_tensor_value_range.size(); j++) {
      if ((ref_out_tensor_value_range.at(j).first != ref_tensor_value_range.at(j).first) ||
          (ref_out_tensor_value_range.at(j).second != ref_tensor_value_range.at(j).second)) {
        // Branches disagree on this element: fall back to the unknown range.
        ref_out_tensor_value_range[j] = std::make_pair(1, -1);
      }
    }
  }
  GELOGD("While updating output desc from subgraphs, set parent node desc value range %s.",
         formats::RangeToString(ref_out_tensor_value_range).c_str());
  dst->SetValueRange(ref_out_tensor_value_range);
  return GRAPH_SUCCESS;
}
// Value ranges are not supported in the dynamic multi-dims scene, so reaching
// this override is always an error; `src` is intentionally unused.
graphStatus InferValueRangePass::UpdateOutputFromSubgraphsForMultiDims(const std::vector<GeTensorDescPtr> &src,
                                                                       GeTensorDescPtr &dst) {
  (void)src;
  REPORT_INNER_ERROR("E19999",
                     "Update TensorDesc %s failed. In dynamic multi-dims size scene, there should be no value range.",
                     dst->GetName().c_str());
  GELOGE(GRAPH_FAILED,
         "[Update][TensorDesc] %s failed. In dynamic multi-dims size scene, there should be no value range.",
         dst->GetName().c_str());
  return GRAPH_FAILED;
}
// Assigns the worst-case value range {1:-1} to every element of every output
// whose shape is fully known. Returns GRAPH_NOT_CHANGED as soon as one
// output has an unknown shape (negative shape size), since no range can be
// generated for it.
graphStatus InferValueRangePass::GenerateWorstValueRange(NodePtr &node) {
  GELOGI("Node %s does not run cpu kernel, because input value range has -1.", node->GetName().c_str());
  OpDescPtr op_desc = node->GetOpDesc();
  const size_t output_num = op_desc->GetOutputsSize();
  for (size_t i = 0; i < output_num; ++i) {
    auto output_desc = op_desc->MutableOutputDesc(i);
    if (output_desc == nullptr) {
      continue;
    }
    const auto out_shape = output_desc->GetShape();
    const auto out_shape_size = out_shape.GetShapeSize();
    if (out_shape_size < 0) {
      GELOGD("Node %s output shape is unknown, cannot infer value range, shape is %s.", node->GetName().c_str(),
             formats::ShapeToString(out_shape).c_str());
      return GRAPH_NOT_CHANGED;
    }
    std::vector<std::pair<int64_t, int64_t>> worst_value_range(out_shape_size, {1, -1});
    output_desc->SetValueRange(worst_value_range);
    GELOGD("Node %s output %zu shape is %s, the generated worst value range is %s.", node->GetName().c_str(), i,
           formats::ShapeToString(out_shape).c_str(), formats::RangeToString(worst_value_range).c_str());
  }
  return GRAPH_SUCCESS;
}
// Fills `output_ptr`'s data with one boundary of `tensor_desc`'s value range:
// the lower bounds when `use_floor_value` is true, else the upper bounds,
// each cast to element type T.
// Returns GRAPH_PARAM_INVALID when the range length does not match the shape
// size, GRAPH_FAILED on allocation failure, GRAPH_NOT_CHANGED when SetData
// fails, otherwise GRAPH_SUCCESS.
template <typename T>
graphStatus InferValueRangePass::ConstructData(const GeTensorDesc &tensor_desc, bool use_floor_value,
                                               GeTensorPtr &output_ptr) {
  std::vector<std::pair<int64_t, int64_t>> value_range;
  (void)tensor_desc.GetValueRange(value_range);
  // One (lower, upper) pair per tensor element is required.
  if (static_cast<int64_t>(value_range.size()) != tensor_desc.GetShape().GetShapeSize()) {
    GELOGW("Value range of input %s is invalid.", tensor_desc.GetName().c_str());
    return GRAPH_PARAM_INVALID;
  }

  size_t value_range_data_num = value_range.size();
  unique_ptr<T[]> buf(new (std::nothrow) T[value_range_data_num]());
  if (buf == nullptr) {
    REPORT_INNER_ERROR("E19999", "New buf failed");
    GELOGE(MEMALLOC_FAILED, "New buf failed");
    return GRAPH_FAILED;
  }
  for (size_t j = 0; j < value_range_data_num; ++j) {
    // Pick the requested boundary of each element's range.
    auto value_range_j = use_floor_value ? value_range[j].first : value_range[j].second;
    buf[j] = static_cast<T>(value_range_j);
  }
  if (output_ptr->SetData(reinterpret_cast<uint8_t *>(buf.get()), value_range_data_num * sizeof(T)) != GRAPH_SUCCESS) {
    GELOGW("Set data failed while constructing value range input tensor.");
    return GRAPH_NOT_CHANGED;
  }
  return GRAPH_SUCCESS;
}
// Dispatches ConstructData<T> on the tensor's data type and stamps that type
// onto `output_ptr`'s desc.
// Fix: DT_UINT32 and DT_UINT64 were missing here although the output-parsing
// switch in ConstructInputAndInferValueRange (GET_DATA_BY_DTYPE) supports
// them, so uint32/uint64 inputs always failed with GRAPH_PARAM_INVALID. The
// two cases are added for consistency.
graphStatus InferValueRangePass::ConstructDataByType(const GeTensorDesc &tensor_desc, bool use_floor_value,
                                                     GeTensorPtr &output_ptr) {
  graphStatus ret = GRAPH_SUCCESS;
  auto data_type = tensor_desc.GetDataType();
  output_ptr->MutableTensorDesc().SetDataType(data_type);
  switch (data_type) {
    case DT_FLOAT:
      ret = ConstructData<float>(tensor_desc, use_floor_value, output_ptr);
      break;
    case DT_DOUBLE:
      ret = ConstructData<double>(tensor_desc, use_floor_value, output_ptr);
      break;
    case DT_UINT8:
      ret = ConstructData<uint8_t>(tensor_desc, use_floor_value, output_ptr);
      break;
    case DT_INT8:
      ret = ConstructData<int8_t>(tensor_desc, use_floor_value, output_ptr);
      break;
    case DT_UINT16:
      ret = ConstructData<uint16_t>(tensor_desc, use_floor_value, output_ptr);
      break;
    case DT_INT16:
      ret = ConstructData<int16_t>(tensor_desc, use_floor_value, output_ptr);
      break;
    case DT_UINT32:
      ret = ConstructData<uint32_t>(tensor_desc, use_floor_value, output_ptr);
      break;
    case DT_INT32:
      ret = ConstructData<int32_t>(tensor_desc, use_floor_value, output_ptr);
      break;
    case DT_UINT64:
      ret = ConstructData<uint64_t>(tensor_desc, use_floor_value, output_ptr);
      break;
    case DT_INT64:
      ret = ConstructData<int64_t>(tensor_desc, use_floor_value, output_ptr);
      break;
    default:
      GELOGW("Data type:%s is not supported.", TypeUtils::DataTypeToSerialString(data_type).c_str());
      ret = GRAPH_PARAM_INVALID;
  }
  return ret;
}
// Builds one input tensor per connected data anchor of `node`:
//   - inputs fed by Const/Constant nodes reuse the constant's weight;
//   - all other inputs are synthesized from one boundary of the input desc's
//     value range (floor when `use_floor_value`, else ceiling).
// Unconnected anchors and null peers are skipped entirely, so the returned
// vector may be shorter than the anchor count. Any failure returns an empty
// vector, which callers treat as "cannot construct inputs".
vector<ConstGeTensorPtr> InferValueRangePass::ConstructInputTensors(const NodePtr &node, bool use_floor_value) {
  vector<ConstGeTensorPtr> input_tensors;
  auto cur_op_desc = node->GetOpDesc();
  auto in_data_anchors = node->GetAllInDataAnchors();
  for (size_t i = 0; i < in_data_anchors.size(); ++i) {
    auto peer_out_anchor = in_data_anchors.at(i)->GetPeerOutAnchor();
    if (peer_out_anchor == nullptr) {
      continue;
    }
    auto peer_node = peer_out_anchor->GetOwnerNode();
    if (peer_node == nullptr) {
      continue;
    }

    // construct input tensor by constant node
    if ((peer_node->GetType() == CONSTANT) || (peer_node->GetType() == CONSTANTOP)) {
      vector<GeTensorPtr> const_weight = OpDescUtils::MutableWeights(peer_node);
      if (const_weight.empty()) {
        GELOGW("MutableWeights failed, weight is empty, node: %s(%s)", peer_node->GetName().c_str(),
               peer_node->GetType().c_str());
        return vector<ConstGeTensorPtr>();
      }
      // const/constant op has only one weight
      if (const_weight.at(0) == nullptr) {
        GELOGW("MutableWeights failed, weight of constant is null, node name: %s(%s)",
               peer_node->GetName().c_str(), peer_node->GetType().c_str());
        return vector<ConstGeTensorPtr>();
      }
      input_tensors.push_back(const_weight.at(0));
      GELOGD("Node %s construct input tensor %zu by constant node.", node->GetName().c_str(), input_tensors.size());
      continue;
    }

    // construct input tensor by boundary of value range
    const auto &input_tensor_desc = cur_op_desc->GetInputDesc(i);
    GeTensorPtr tmp_tensor_ptr = MakeShared<GeTensor>(input_tensor_desc);
    if (tmp_tensor_ptr == nullptr) {
      REPORT_INNER_ERROR("E19999", "Make shared failed");
      GELOGE(MEMALLOC_FAILED, "Make shared failed");
      return vector<ConstGeTensorPtr>();
    }

    auto ret = ConstructDataByType(input_tensor_desc, use_floor_value, tmp_tensor_ptr);
    if (ret != GRAPH_SUCCESS) {
      GELOGW("Construct input tensor by boundary of value range failed for input %s.",
             input_tensor_desc.GetName().c_str());
      return vector<ConstGeTensorPtr>();
    }
    input_tensors.push_back(tmp_tensor_ptr);
    GELOGD("Node %s construct input tensor %zu by input desc value range.", node->GetName().c_str(),
           input_tensors.size());
  }
  return input_tensors;
}
// Computes output value ranges by running the node's cpu kernel twice — once
// with every input pinned to its range floor, once to its range ceiling — and
// pairing the two result tensors element-wise into [lower, upper] ranges.
// Fixes: (1) the kernel output vectors were indexed up to GetOutputsSize()
// without checking they actually contain that many tensors (out-of-bounds
// read on a misbehaving kernel); (2) MutableOutputDesc(i) was dereferenced
// without the null guard that sibling GenerateWorstValueRange applies.
graphStatus InferValueRangePass::ConstructInputAndInferValueRange(NodePtr &node) {
  // Lower boundary pass.
  auto inputs = ConstructInputTensors(node, true);
  if (inputs.empty()) {
    return GRAPH_PARAM_INVALID;
  }
  vector<GeTensorPtr> lower_boundary_outputs;
  auto ret = RunCpuKernelForValueRange(node, inputs, lower_boundary_outputs);
  if (ret != SUCCESS) {
    GELOGW("Node %s run cpu kernel failed while calculating value range.", node->GetName().c_str());
    return GRAPH_PARAM_INVALID;
  }

  // Upper boundary pass.
  inputs = ConstructInputTensors(node, false);
  if (inputs.empty()) {
    return GRAPH_PARAM_INVALID;
  }
  vector<GeTensorPtr> upper_boundary_outputs;
  ret = RunCpuKernelForValueRange(node, inputs, upper_boundary_outputs);
  if (ret != SUCCESS) {
    GELOGW("Node %s run cpu kernel failed while calculating value range.", node->GetName().c_str());
    return GRAPH_PARAM_INVALID;
  }

  // construct value range from output tensor
  OpDescPtr node_desc = node->GetOpDesc();
  size_t node_output_desc_size = node_desc->GetOutputsSize();
  if ((lower_boundary_outputs.size() < node_output_desc_size) ||
      (upper_boundary_outputs.size() < node_output_desc_size)) {
    GELOGW("Node %s cpu kernel produced %zu/%zu outputs, fewer than the %zu outputs the op declares.",
           node->GetName().c_str(), lower_boundary_outputs.size(), upper_boundary_outputs.size(),
           node_output_desc_size);
    return GRAPH_PARAM_INVALID;
  }
  std::vector<std::pair<int64_t, int64_t>> output_tensor_value_range;
  for (size_t i = 0; i < node_output_desc_size; ++i) {
    output_tensor_value_range.clear();
    auto output_tensor_desc = node_desc->MutableOutputDesc(i);
    if (output_tensor_desc == nullptr) {
      continue;
    }
    auto output_shape_size = output_tensor_desc->GetShape().GetShapeSize();
    auto lower_boundary_tensor = lower_boundary_outputs[i];
    auto lower_boundary_shape = lower_boundary_tensor->GetTensorDesc().GetShape();
    auto upper_boundary_tensor = upper_boundary_outputs[i];
    auto upper_boundary_shape = upper_boundary_tensor->GetTensorDesc().GetShape();
    if (lower_boundary_shape.GetShapeSize() != output_shape_size ||
        upper_boundary_shape.GetShapeSize() != output_shape_size) {
      GELOGD(
          "Cpu kernel result shapes %s, %s and output shape %s do not match, can not infer value range for output %s.",
          formats::ShapeToString(lower_boundary_shape).c_str(), formats::ShapeToString(upper_boundary_shape).c_str(),
          formats::ShapeToString(output_tensor_desc->GetShape()).c_str(), output_tensor_desc->GetName().c_str());
      return GRAPH_PARAM_INVALID;
    }

    auto data_type = output_tensor_desc->GetDataType();
    switch (data_type) {
      GET_DATA_BY_DTYPE(DT_INT8, int8_t)
      GET_DATA_BY_DTYPE(DT_INT16, int16_t)
      GET_DATA_BY_DTYPE(DT_INT32, int32_t)
      GET_DATA_BY_DTYPE(DT_INT64, int64_t)
      GET_DATA_BY_DTYPE(DT_UINT8, uint8_t)
      GET_DATA_BY_DTYPE(DT_UINT16, uint16_t)
      GET_DATA_BY_DTYPE(DT_UINT32, uint32_t)
      GET_DATA_BY_DTYPE(DT_UINT64, uint64_t)
      GET_DATA_BY_DTYPE(DT_FLOAT, float)
      GET_DATA_BY_DTYPE(DT_DOUBLE, double)
      default:
        GELOGW("Data type:%s is not supported.", TypeUtils::DataTypeToSerialString(data_type).c_str());
        return GRAPH_PARAM_INVALID;
    }
    output_tensor_desc->SetValueRange(output_tensor_value_range);
    GELOGD("Node %s calculates output %zu value range %s by running cpu kernel.", node->GetName().c_str(), i,
           formats::RangeToString(output_tensor_value_range).c_str());
  }
  return GRAPH_SUCCESS;
}
template <typename T> | |||||
void InferValueRangePass::ConstructValueRange(const GeTensorPtr &left_tensor, const GeTensorPtr &right_tensor, | |||||
std::vector<std::pair<int64_t, int64_t>> &value_range) { | |||||
auto x = reinterpret_cast<const T *>(left_tensor->GetData().GetData()); | |||||
auto y = reinterpret_cast<const T *>(right_tensor->GetData().GetData()); | |||||
if (x == nullptr || y == nullptr) { | |||||
GELOGI("Output tensor of cpu kernel does not have data, no way to set value range."); | |||||
return; | |||||
} | |||||
for (auto j = 0; j < left_tensor->GetTensorDesc().GetShape().GetShapeSize(); ++j) { | |||||
auto left = static_cast<int64_t>(*(x + j)); | |||||
auto right = static_cast<int64_t>(*(y + j)); | |||||
value_range.emplace_back(std::make_pair(left, right)); | |||||
} | |||||
} | |||||
} // namespace ge |
@@ -0,0 +1,49 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_GRAPH_PASSES_INFER_VALUE_RANGE_PASS_H_ | |||||
#define GE_GRAPH_PASSES_INFER_VALUE_RANGE_PASS_H_ | |||||
#include "graph/passes/infer_base_pass.h" | |||||
namespace ge { | |||||
// Graph pass that propagates per-element value ranges ([lower, upper] bounds)
// across nodes, either through an operator-registered infer function or by
// running the op's cpu kernel on the range boundaries.
class InferValueRangePass : public InferBasePass {
 public:
  // Infers the value range of `node`'s outputs (registered func or cpu kernel).
  graphStatus Infer(NodePtr &node) override;

 private:
  // Formats shape, shape range, and value range of a tensor desc for logs.
  std::string SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const override;
  // Copies src's value range to dst; `changed` reports whether it differed.
  graphStatus UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override;
  // Merges subgraph output ranges into the parent output (disagreements -> {1:-1}).
  graphStatus UpdateOutputFromSubgraphs(const std::vector<GeTensorDescPtr> &src, GeTensorDescPtr &dst) override;
  // Always fails: value ranges are unsupported in the dynamic multi-dims scene.
  graphStatus UpdateOutputFromSubgraphsForMultiDims(const std::vector<GeTensorDescPtr> &src,
                                                    GeTensorDescPtr &dst) override;
  // True when the node's type registered a value-range inferer and its
  // when_call condition (dynamic inputs / const-or-ranged inputs) holds.
  bool NeedInfer(const NodePtr &node) const override;

  bool InputIsDynamic(const NodePtr &node) const;
  bool InputIsConstOrHasValueRange(const NodePtr &node) const;
  bool InputHasUnknownValueRange(const NodePtr &node) const;
  // Sets every known-shape output's range to the worst case {1:-1}.
  graphStatus GenerateWorstValueRange(NodePtr &node);
  // Fills output_ptr with one boundary (floor/ceiling) of tensor_desc's range, as T.
  template <typename T>
  graphStatus ConstructData(const GeTensorDesc &tensor_desc, bool use_floor_value, GeTensorPtr &output_ptr);
  // Dispatches ConstructData<T> on the tensor's data type.
  graphStatus ConstructDataByType(const GeTensorDesc &tensor_desc, bool use_floor_value, GeTensorPtr &output_ptr);
  // Builds one input tensor per connected anchor (constant weight or range boundary).
  vector<ConstGeTensorPtr> ConstructInputTensors(const NodePtr &node, bool use_floor_value);
  // Zips two boundary tensors into (lower, upper) int64 pairs.
  template <typename T>
  void ConstructValueRange(const GeTensorPtr &left_tensor, const GeTensorPtr &right_tensor,
                           std::vector<std::pair<int64_t, int64_t>> &value_range);
  // Runs the cpu kernel on both boundaries and writes the resulting output ranges.
  graphStatus ConstructInputAndInferValueRange(NodePtr &node);
};
} // namespace ge | |||||
#endif // GE_GRAPH_PASSES_INFER_VALUE_RANGE_PASS_H_ |
@@ -54,6 +54,7 @@ | |||||
#include "graph/passes/hccl_group_pass.h" | #include "graph/passes/hccl_group_pass.h" | ||||
#include "graph/passes/identity_pass.h" | #include "graph/passes/identity_pass.h" | ||||
#include "graph/passes/infershape_pass.h" | #include "graph/passes/infershape_pass.h" | ||||
#include "graph/passes/infer_value_range_pass.h" | |||||
#include "graph/passes/merge_pass.h" | #include "graph/passes/merge_pass.h" | ||||
#include "graph/passes/net_output_pass.h" | #include "graph/passes/net_output_pass.h" | ||||
#include "graph/passes/no_use_reshape_remove_pass.h" | #include "graph/passes/no_use_reshape_remove_pass.h" | ||||
@@ -2016,6 +2017,8 @@ Status GraphPrepare::InferShapeForPreprocess() { | |||||
names_to_passes.emplace_back("DimensionComputePass", &dimension_compute_pass); | names_to_passes.emplace_back("DimensionComputePass", &dimension_compute_pass); | ||||
ConstantFoldingPass constant_folding_pass; | ConstantFoldingPass constant_folding_pass; | ||||
names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass); | names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass); | ||||
InferValueRangePass infer_value_pass; | |||||
names_to_passes.emplace_back("InferValuePass", &infer_value_pass); | |||||
int32_t dev_count = 0; | int32_t dev_count = 0; | ||||
AicpuConstantFoldingPass aicpu_constant_folding_pass; | AicpuConstantFoldingPass aicpu_constant_folding_pass; | ||||
@@ -1 +1 @@ | |||||
Subproject commit 2ad00e17886fd06c0d00f8a8cf370783a3d31818 | |||||
Subproject commit 9e4a51a9602195b82e326b853f5adbfefc3972b6 |
@@ -221,7 +221,9 @@ set(COMMON_SRC_FILES | |||||
"${GE_CODE_DIR}/ge/graph/passes/shape_operate_op_remove_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/shape_operate_op_remove_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/assert_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/assert_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/dropout_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/dropout_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/infer_base_pass.cc" | |||||
"${GE_CODE_DIR}/ge/graph/passes/infershape_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/infershape_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/infer_value_range_pass.cc" | |||||
"${GE_CODE_DIR}/ge/graph/passes/unused_const_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/unused_const_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/permute_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/permute_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/ctrl_edge_transfer_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/ctrl_edge_transfer_pass.cc" | ||||
@@ -535,7 +537,9 @@ set(GRAPH_PASS_COMMON_SRC_FILES | |||||
"${GE_CODE_DIR}/ge/graph/passes/transpose_transdata_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/transpose_transdata_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/hccl_memcpy_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/hccl_memcpy_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/no_use_reshape_remove_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/no_use_reshape_remove_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/infer_base_pass.cc" | |||||
"${GE_CODE_DIR}/ge/graph/passes/infershape_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/infershape_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/infer_value_range_pass.cc" | |||||
"${GE_CODE_DIR}/ge/ge_local_engine/engine/host_cpu_engine.cc" | "${GE_CODE_DIR}/ge/ge_local_engine/engine/host_cpu_engine.cc" | ||||
"${GE_CODE_DIR}/ge/analyzer/analyzer.cc" | "${GE_CODE_DIR}/ge/analyzer/analyzer.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/net_output_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/net_output_pass.cc" | ||||
@@ -662,6 +666,8 @@ set(DISTINCT_GRAPH_LOAD_TEST_FILES | |||||
) | ) | ||||
set(PASS_TEST_FILES | set(PASS_TEST_FILES | ||||
"graph/passes/infer_value_range_pass_unittest.cc" | |||||
"graph/passes/infer_base_pass_unittest.cc" | |||||
"graph/passes/prune_pass_unittest.cc" | "graph/passes/prune_pass_unittest.cc" | ||||
"graph/passes/enter_pass_unittest.cc" | "graph/passes/enter_pass_unittest.cc" | ||||
"graph/passes/switch_op_pass_unittest.cc" | "graph/passes/switch_op_pass_unittest.cc" | ||||
@@ -720,7 +726,6 @@ set(PASS_TEST_FILES | |||||
"graph/passes/memcpy_addr_async_unittest.cc" | "graph/passes/memcpy_addr_async_unittest.cc" | ||||
"graph/passes/hccl_continuous_pass_unittest.cc" | "graph/passes/hccl_continuous_pass_unittest.cc" | ||||
"graph/passes/hccl_memcpy_pass_unittest.cc" | "graph/passes/hccl_memcpy_pass_unittest.cc" | ||||
) | ) | ||||
set(KERNEL_TEST_FILES | set(KERNEL_TEST_FILES | ||||
@@ -0,0 +1,359 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <gtest/gtest.h> | |||||
#include "graph/passes/infer_base_pass.h" | |||||
#include "graph/debug/ge_attr_define.h" | |||||
#include "graph/utils/tensor_utils.h" | |||||
#include "graph/utils/graph_utils.h" | |||||
#include "graph_builder_utils.h" | |||||
using namespace std; | |||||
using namespace testing; | |||||
namespace ge { | |||||
class ChildPassBuilder; | |||||
static const char *kInferTimes = "infer_times"; | |||||
// Test double for InferBasePass: records how often each overridable hook is
// called and returns configurable results. Instances are configured through
// the friend ChildPassBuilder, which sets the private behavior flags.
class InferBasePassStub : public InferBasePass {
 public:
  friend class ChildPassBuilder;
  // Counts the call and bumps the kInferTimes attr on every output desc so
  // tests can observe how many times inference touched each tensor.
  graphStatus Infer(NodePtr &node) override{
    call_infer_times++;
    for (size_t i = 0; i < node->GetOutDataNodesSize(); ++i) {
      auto output_td = node->GetOpDesc()->MutableOutputDesc(i);
      int times = 0;
      AttrUtils::GetInt(output_td, kInferTimes, times);
      AttrUtils::SetInt(output_td, kInferTimes, times + 1);
    }
    return infer_result_;
  };

  // Call counters inspected by tests.
  int32_t call_infer_times = 0;
  int32_t call_update_tensor_desc_times = 0;
  int32_t call_update_from_subgraph_times = 0;
  int32_t call_update_from_subgraph_multi_dims_times = 0;
  // Every (src, dst) pair passed to UpdateTensorDesc, in call order.
  std::vector<std::pair<GeTensorDescPtr, GeTensorDescPtr>> update_td_pairs;

 private:
  bool NeedInfer(const NodePtr &node) const override {
    return need_infer_;
  };
  std::string SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const override { return "test SerialTensorInfo"; };
  // Propagates the kInferTimes attr from src to dst (mimicking a real update)
  // and reports the configured `td_changed_` flag.
  graphStatus UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override {
    call_update_tensor_desc_times++;
    changed = td_changed_;
    int times = 0;
    if (AttrUtils::GetInt(src, kInferTimes, times)) {
      AttrUtils::SetInt(dst, kInferTimes, times);
    }
    update_td_pairs.emplace_back(src, dst);
    return GRAPH_SUCCESS;
  };
  graphStatus UpdateOutputFromSubgraphs(const std::vector<GeTensorDescPtr> &src, GeTensorDescPtr &dst) override {
    call_update_from_subgraph_times++;
    return GRAPH_SUCCESS;
  };
  graphStatus UpdateOutputFromSubgraphsForMultiDims(const std::vector<GeTensorDescPtr> &src,
                                                    GeTensorDescPtr &dst) override {
    call_update_from_subgraph_multi_dims_times++;
    return GRAPH_SUCCESS;
  };

  // Behavior knobs, written only by ChildPassBuilder::Build.
  bool td_changed_;
  bool need_infer_;
  graphStatus infer_result_;
};
class ChildPassBuilder { | |||||
public: | |||||
ChildPassBuilder &SetNeedInferFlag(bool flag) { | |||||
need_infer_ = flag; | |||||
return *this; | |||||
} | |||||
ChildPassBuilder &SetInferResult(graphStatus ret) { | |||||
infer_result_ = ret; | |||||
return *this; | |||||
} | |||||
ChildPassBuilder &SetTdChangedFlag(bool changed_flag) { | |||||
td_changed_ = changed_flag; | |||||
return *this; | |||||
} | |||||
InferBasePassStub Build() { | |||||
InferBasePassStub ib; | |||||
ib.td_changed_ = td_changed_; | |||||
ib.need_infer_ = need_infer_; | |||||
ib.infer_result_ = infer_result_; | |||||
return ib; | |||||
} | |||||
private: | |||||
bool td_changed_ = false; | |||||
bool need_infer_ = true; | |||||
graphStatus infer_result_ = GRAPH_SUCCESS; | |||||
}; | |||||
// Test fixture for InferBasePass behaviour, exercised through InferBasePassStub.
// No per-test setup/teardown is required.
class UtestGraphInferBasePassStub : public testing::Test {
 protected:
  void SetUp() {}
  void TearDown() {}
};
/* | |||||
* data1 data2 | |||||
* \ / | |||||
* sub1 | |||||
* | | |||||
* netoutput | |||||
*/ | |||||
// Builds the branch subgraph (data1/data2 -> Sub -> netoutput) used as the
// single Case branch. Data/NetOutput descs are bound to the parent node via
// the "_parent_node_index" attribute.
ut::GraphBuilder TestSubgraphBuilder() {
  ut::GraphBuilder builder("branch_graph");

  // Subgraph inputs, mapped to parent inputs 0 and 1.
  std::vector<int64_t> data1_shape = {1, 1};
  auto data1 = builder.AddNode("data1_1", "Data", 1, 1, FORMAT_NCHW, DT_INT32, data1_shape);
  auto data1_desc = data1->GetOpDesc();
  EXPECT_NE(data1_desc, nullptr);
  AttrUtils::SetInt(data1_desc, "_parent_node_index", 0);

  std::vector<int64_t> data2_shape = {2, 2};
  auto data2 = builder.AddNode("data2_1", "Data", 1, 1, FORMAT_NCHW, DT_INT32, data2_shape);
  auto data2_desc = data2->GetOpDesc();
  EXPECT_NE(data2_desc, nullptr);
  AttrUtils::SetInt(data2_desc, "_parent_node_index", 1);

  auto sub1 = builder.AddNode("Sub", "Sub", 2, 1);

  // Subgraph output, mapped to parent output 0.
  std::vector<int64_t> out_shape = {8, 8};
  auto netoutput = builder.AddNode("output", NETOUTPUT, 1, 0, FORMAT_NCHW, DT_INT32, out_shape);
  auto input0_desc = netoutput->GetOpDesc()->MutableInputDesc(0);
  EXPECT_NE(input0_desc, nullptr);
  AttrUtils::SetInt(input0_desc, "_parent_node_index", 0);

  builder.AddDataEdge(data1, 0, sub1, 0);
  builder.AddDataEdge(data2, 0, sub1, 1);
  builder.AddDataEdge(sub1, 0, netoutput, 0);
  return builder;
}
/* | |||||
* data1 data2 | |||||
* \ / | |||||
* case1 | |||||
* | | |||||
* netoutput | |||||
*/ | |||||
// Builds the root graph (data1/data2 -> case1 -> netoutput) and attaches the
// branch graph from TestSubgraphBuilder() as case1's only subgraph instance.
ut::GraphBuilder RootGraphBuilder() {
  ut::GraphBuilder builder("root_graph");
  auto data1 = builder.AddNode("data1", "Data", 0, 1);
  auto data2 = builder.AddNode("data2", "Data", 0, 1);
  auto case1 = builder.AddNode("case1", CASE, 2, 1);
  auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  builder.AddDataEdge(data1, 0, case1, 0);
  builder.AddDataEdge(data2, 0, case1, 1);
  builder.AddDataEdge(case1, 0, netoutput, 0);

  // Register the branch graph on the Case op and wire parent/child links.
  auto root_graph = builder.GetGraph();
  auto subgraph_builder = TestSubgraphBuilder();
  auto branch_graph = subgraph_builder.GetGraph();
  case1->GetOpDesc()->AddSubgraphName(branch_graph->GetName());
  case1->GetOpDesc()->SetSubgraphInstanceName(0, branch_graph->GetName());
  branch_graph->SetParentNode(case1);
  branch_graph->SetParentGraph(root_graph);
  EXPECT_EQ(root_graph->AddSubgraph(branch_graph->GetName(), branch_graph), GRAPH_SUCCESS);
  return builder;
}
/* | |||||
* data1 data2 | |||||
* \ / | |||||
* add1 | |||||
* | | |||||
* netoutput | |||||
*/ | |||||
// Builds a flat graph with no subgraphs: two Data nodes feeding an Add whose
// result goes to NetOutput.
ut::GraphBuilder NoSubgraphBuilder() {
  ut::GraphBuilder graph_builder("no_subgraph");
  auto input0 = graph_builder.AddNode("data1", "Data", 0, 1);
  auto input1 = graph_builder.AddNode("data2", "Data", 0, 1);
  auto add = graph_builder.AddNode("add1", ADD, 2, 1);
  auto out = graph_builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  graph_builder.AddDataEdge(input0, 0, add, 0);
  graph_builder.AddDataEdge(input1, 0, add, 1);
  graph_builder.AddDataEdge(add, 0, out, 0);
  return graph_builder;
}
// When NeedInfer() returns true, Run() must invoke Infer() exactly once and
// the stub stamps kInferTimes onto the node's output desc.
TEST_F(UtestGraphInferBasePassStub, CallInfer_WhenNeedInferReturnTrue) {
  auto builder = NoSubgraphBuilder();
  auto test_graph = builder.GetGraph();
  auto add_node = test_graph->FindNode("add1");
  EXPECT_NE(add_node, nullptr);
  ChildPassBuilder pass_builder;
  auto stub_base_pass = pass_builder.Build();
  // NeedInfer return true
  EXPECT_EQ(stub_base_pass.Run(add_node), SUCCESS);
  EXPECT_EQ(stub_base_pass.call_infer_times, 1);
  int times = -1;
  EXPECT_TRUE(AttrUtils::GetInt(add_node->GetOpDesc()->GetOutputDescPtr(0), kInferTimes, times));
  EXPECT_EQ(times, 1);
}
// When NeedInfer() returns false, Run() must skip Infer() entirely, so the
// kInferTimes attribute is never written to the output desc.
TEST_F(UtestGraphInferBasePassStub, NotCallInfer_WhenNeedInferReturnFalse) {
  auto builder = NoSubgraphBuilder();
  auto test_graph = builder.GetGraph();
  auto add_node = test_graph->FindNode("add1");
  EXPECT_NE(add_node, nullptr);
  ChildPassBuilder pass_builder;
  auto stub_base_pass = pass_builder.SetNeedInferFlag(false).Build();
  // NeedInfer return false
  EXPECT_EQ(stub_base_pass.Run(add_node), SUCCESS);
  EXPECT_EQ(stub_base_pass.call_infer_times, 0);
  int times = -1;
  EXPECT_FALSE(AttrUtils::GetInt(add_node->GetOpDesc()->GetOutputDescPtr(0), kInferTimes, times));
}
// On plain success the pass propagates add1's output desc to netoutput's
// peer input desc exactly once and schedules no immediate repass.
TEST_F(UtestGraphInferBasePassStub, NotAddCurNodeRepass_CallUpdatePeerNode_WhenInferReturnSuccess) {
  auto builder = NoSubgraphBuilder();
  auto test_graph = builder.GetGraph();
  auto add_node = test_graph->FindNode("add1");
  auto netoutput = test_graph->FindNode("netoutput");
  EXPECT_NE(add_node, nullptr);
  EXPECT_NE(netoutput, nullptr);
  ChildPassBuilder pass_builder;
  auto stub_base_pass = pass_builder.Build();
  EXPECT_EQ(stub_base_pass.Run(add_node), SUCCESS);
  EXPECT_EQ(stub_base_pass.call_infer_times, 1);
  EXPECT_EQ(stub_base_pass.call_update_tensor_desc_times, 1);
  // Expect exactly one (src, dst) propagation: add1 output 0 -> netoutput input 0.
  std::vector<std::pair<GeTensorDescPtr, GeTensorDescPtr>> expected_updated_tensor_desc_pairs = {
    {add_node->GetOpDesc()->MutableOutputDesc(0), netoutput->GetOpDesc()->MutableInputDesc(0)}};
  EXPECT_EQ(stub_base_pass.update_td_pairs, expected_updated_tensor_desc_pairs);
  EXPECT_EQ(stub_base_pass.GetNodesNeedRePassImmediately(), std::unordered_set<NodePtr>({}));
}
// GRAPH_NODE_NEED_REPASS must re-queue the current node for an immediate
// repass and skip the peer tensor-desc propagation.
TEST_F(UtestGraphInferBasePassStub, AddCurNodeRepass_NotCallUpdatePeerNode_WhenInferReturnNeedRepass) {
  auto builder = NoSubgraphBuilder();
  auto test_graph = builder.GetGraph();
  auto add_node = test_graph->FindNode("add1");
  EXPECT_NE(add_node, nullptr);
  ChildPassBuilder pass_builder;
  auto stub_base_pass = pass_builder.SetInferResult(GRAPH_NODE_NEED_REPASS).Build();
  // do re_pass
  EXPECT_EQ(stub_base_pass.Run(add_node), SUCCESS);
  EXPECT_EQ(stub_base_pass.call_infer_times, 1);
  EXPECT_EQ(stub_base_pass.call_update_tensor_desc_times, 0);
  EXPECT_EQ(stub_base_pass.GetNodesNeedRePassImmediately(), std::unordered_set<NodePtr>({add_node}));
}
// When UpdateTensorDesc reports no change, the peer desc is still updated
// (kInferTimes propagates to netoutput's input) but the peer node is NOT
// queued for repass.
TEST_F(UtestGraphInferBasePassStub, NotAddPeerNodeRepass_AfterUpdatePeerNode_WhenUnchanged) {
  auto builder = NoSubgraphBuilder();
  auto test_graph = builder.GetGraph();
  auto add_node = test_graph->FindNode("add1");
  auto netoutput = test_graph->FindNode("netoutput");
  EXPECT_NE(add_node, nullptr);
  EXPECT_NE(netoutput, nullptr);
  ChildPassBuilder pass_builder;
  auto stub_base_pass = pass_builder.Build();
  EXPECT_EQ(stub_base_pass.Run(add_node), SUCCESS);
  EXPECT_EQ(stub_base_pass.call_update_tensor_desc_times, 1);
  EXPECT_EQ(stub_base_pass.GetNodesNeedRePassImmediately(), std::unordered_set<NodePtr>({}));
  int times = -1;
  EXPECT_TRUE(AttrUtils::GetInt(add_node->GetOpDesc()->GetOutputDescPtr(0), kInferTimes, times));
  EXPECT_EQ(times, 1);
  times = -1;
  EXPECT_TRUE(AttrUtils::GetInt(netoutput->GetOpDesc()->GetInputDescPtr(0), kInferTimes, times));
  EXPECT_EQ(times, 1);
}
// When UpdateTensorDesc reports the desc changed, the peer node (netoutput)
// must be queued for an immediate repass.
TEST_F(UtestGraphInferBasePassStub, AddPeerNodeRepass_AfterUpdatePeerNode_WhenChanged) {
  auto builder = NoSubgraphBuilder();
  auto test_graph = builder.GetGraph();
  auto add_node = test_graph->FindNode("add1");
  auto netoutput = test_graph->FindNode("netoutput");
  EXPECT_NE(add_node, nullptr);
  EXPECT_NE(netoutput, nullptr);
  ChildPassBuilder pass_builder;
  auto stub_base_pass = pass_builder.SetTdChangedFlag(true).Build();
  EXPECT_EQ(stub_base_pass.Run(add_node), SUCCESS);
  EXPECT_EQ(stub_base_pass.call_update_tensor_desc_times, 1);
  EXPECT_EQ(stub_base_pass.GetNodesNeedRePassImmediately(), std::unordered_set<NodePtr>({netoutput}));
}
// Before entering the subgraph, the Case node's input descs are copied onto
// each subgraph Data node (both the Data's input and output desc), i.e.
// 2 Data nodes x 2 descs = 4 UpdateTensorDesc calls. Because Infer() returns
// NEED_REPASS, the parent's peer outputs are not touched.
TEST_F(UtestGraphInferBasePassStub, TestUpdateSubgraphData_WhenBeforeSubgraph) {
  auto builder = RootGraphBuilder();
  auto parent_graph = builder.GetGraph();
  auto subgraphs = parent_graph->GetAllSubgraphs();
  EXPECT_EQ(subgraphs.size(), 1);
  auto case_node = parent_graph->FindNode("case1");
  auto data1 = subgraphs[0]->FindNode("data1_1");
  auto data2 = subgraphs[0]->FindNode("data2_1");
  EXPECT_NE(case_node, nullptr);
  EXPECT_NE(data1, nullptr);
  EXPECT_NE(data2, nullptr);
  ChildPassBuilder pass_builder;
  auto stub_base_pass = pass_builder.SetInferResult(GRAPH_NODE_NEED_REPASS).Build();
  EXPECT_EQ(stub_base_pass.Run(case_node), SUCCESS);
  // when GRAPH_NODE_NEED_REPASS, not update peer node, only update two data, update input and output, 2*2
  EXPECT_EQ(stub_base_pass.call_update_tensor_desc_times, 4);
  std::vector<std::pair<GeTensorDescPtr, GeTensorDescPtr>> expected_updated_tensor_desc_pairs = {
    {case_node->GetOpDesc()->MutableInputDesc(0), data1->GetOpDesc()->MutableInputDesc(0)},
    {case_node->GetOpDesc()->MutableInputDesc(0), data1->GetOpDesc()->MutableOutputDesc(0)},
    {case_node->GetOpDesc()->MutableInputDesc(1), data2->GetOpDesc()->MutableInputDesc(0)},
    {case_node->GetOpDesc()->MutableInputDesc(1), data2->GetOpDesc()->MutableOutputDesc(0)},
  };
  EXPECT_EQ(stub_base_pass.update_td_pairs, expected_updated_tensor_desc_pairs);
}
// In the after-subgraph phase (kOptimizeAfterSubGraph option set) with no
// ATTR_NAME_BATCH_NUM attribute, the parent output must be rebuilt via the
// single-dims path (UpdateOutputFromSubgraphs), not the multi-dims path.
TEST_F(UtestGraphInferBasePassStub, TestUpdateParentNodeOutput_WhenAfterSubgraph) {
  auto builder = RootGraphBuilder();
  auto parent_graph = builder.GetGraph();
  auto subgraphs = parent_graph->GetAllSubgraphs();
  EXPECT_EQ(subgraphs.size(), 1);
  auto case_node = parent_graph->FindNode("case1");
  EXPECT_NE(case_node, nullptr);
  ChildPassBuilder pass_builder;
  auto stub_base_pass = pass_builder.Build();
  stub_base_pass.SetOption(kOptimizeAfterSubGraph, "");
  EXPECT_EQ(stub_base_pass.Run(case_node), SUCCESS);
  EXPECT_EQ(stub_base_pass.call_update_from_subgraph_times, 1);
  EXPECT_EQ(stub_base_pass.call_update_from_subgraph_multi_dims_times, 0);
}
// In the after-subgraph phase, a Case node carrying ATTR_NAME_BATCH_NUM must
// take the multi-dims update path (UpdateOutputFromSubgraphsForMultiDims)
// instead of the single-dims one.
// Fix: the original dereferenced case_node (AttrUtils::SetInt on its op desc)
// BEFORE checking it for null, and used a non-fatal EXPECT_NE that would not
// stop the test on failure. The null check now comes first and is fatal.
TEST_F(UtestGraphInferBasePassStub, TestUpdateParentNodeOutputForMultiDims_WhenAfterSubgraph) {
  auto builder = RootGraphBuilder();
  auto parent_graph = builder.GetGraph();
  auto subgraphs = parent_graph->GetAllSubgraphs();
  EXPECT_EQ(subgraphs.size(), 1);
  auto case_node = parent_graph->FindNode("case1");
  ASSERT_NE(case_node, nullptr);  // must hold before dereferencing below
  auto set_ret = AttrUtils::SetInt(case_node->GetOpDesc(), ATTR_NAME_BATCH_NUM, 2);
  EXPECT_EQ(set_ret, true);
  ChildPassBuilder pass_builder;
  auto stub_base_pass = pass_builder.Build();
  stub_base_pass.SetOption(kOptimizeAfterSubGraph, "");
  EXPECT_EQ(stub_base_pass.Run(case_node), SUCCESS);
  EXPECT_EQ(stub_base_pass.call_update_from_subgraph_times, 0);
  EXPECT_EQ(stub_base_pass.call_update_from_subgraph_multi_dims_times, 1);
}
} // namespace ge |
@@ -0,0 +1,583 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <gtest/gtest.h> | |||||
#define protected public | |||||
#define private public | |||||
#include "graph/passes/infer_value_range_pass.h" | |||||
#include "graph/utils/tensor_utils.h" | |||||
#include "graph/utils/graph_utils.h" | |||||
#include "graph_builder_utils.h" | |||||
#include "inc/external/graph/operator_reg.h" | |||||
#include "inc/external/graph/operator.h" | |||||
#include "inc/external/graph/operator_factory.h" | |||||
#include "inc/graph/operator_factory_impl.h" | |||||
#include "inc/kernel.h" | |||||
#include "inc/kernel_factory.h" | |||||
using namespace std; | |||||
using namespace testing; | |||||
namespace ge { | |||||
// Test fixture for InferValueRangePass. No per-test setup/teardown is needed.
class UtestGraphInferValueRangePass : public testing::Test {
 protected:
  void SetUp() {}
  void TearDown() {}
};
/* | |||||
* data1 const1 | |||||
* \ / | |||||
* case1 | |||||
* | | |||||
* relu10 | |||||
* | | |||||
* netoutput | |||||
*/ | |||||
// Builds the parent graph (data1/const1 -> case1 -> relu10 -> netoutput).
// Both Case inputs carry explicit value ranges; input 0 has an unknown dim
// (-1) so the graph exercises the dynamic-shape path.
ut::GraphBuilder ParentGraphBuilder() {
  ut::GraphBuilder builder("g1");
  auto data1 = builder.AddNode("data1", "Data", 0, 1);
  const std::vector<int64_t> const_shape = {1};
  auto const1 = builder.AddNode("const1", "Const", 0, 1, FORMAT_NCHW, DT_INT32, const_shape);
  auto case1 = builder.AddNode("case1", CASE, 2, 1);
  auto relu1 = builder.AddNode("relu10", "Relu", 1, 1);
  auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0);

  // Give const1 a concrete int32 weight so it acts as a real constant input.
  int32_t weight[1] = {1};
  GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_INT32);
  GeTensorPtr weight_tensor = std::make_shared<GeTensor>(weight_desc, (uint8_t *)weight, sizeof(weight));
  OpDescUtils::SetWeights(const1, {weight_tensor});

  // Attach shapes and value ranges to the two Case inputs.
  auto case_in0 = case1->GetOpDesc()->MutableInputDesc(0);
  case_in0->SetShape(GeShape({1, 1, -1, 224}));
  case_in0->SetValueRange({{1, 1}, {1, 1}, {1, -1}, {1, 224}});
  auto case_in1 = case1->GetOpDesc()->MutableInputDesc(1);
  case_in1->SetShape(GeShape({1, 1}));
  case_in1->SetValueRange({{1, 100}, {1, 10}});

  builder.AddDataEdge(data1, 0, case1, 0);
  builder.AddDataEdge(const1, 0, case1, 1);
  builder.AddDataEdge(case1, 0, relu1, 0);
  builder.AddDataEdge(relu1, 0, netoutput, 0);
  return builder;
}
/* | |||||
* data1 data2 | |||||
* \ / | |||||
* switch | |||||
* / \ | |||||
* relu1 relu2 | |||||
* \ / | |||||
* merge | |||||
* | | |||||
* netoutput | |||||
*/ | |||||
// Builds one Switch/Merge branch subgraph. `num` is appended to every node
// name so multiple branch instances can coexist in the same parent graph.
// The NetOutput input carries an unknown value range ({1,-1} per dim).
ut::GraphBuilder SwitchSubgraphBuilder(string graph_name, uint32_t num) {
  ut::GraphBuilder builder(graph_name);
  const std::string suffix = "_" + std::to_string(num);

  // Subgraph inputs, bound to parent inputs 0 and 1 via _parent_node_index.
  std::vector<int64_t> data1_shape = {2, 2};
  auto data1 = builder.AddNode("data1" + suffix, "Data", 1, 1, FORMAT_NCHW, DT_INT32, data1_shape);
  auto data1_desc = data1->GetOpDesc();
  EXPECT_NE(data1_desc, nullptr);
  AttrUtils::SetInt(data1_desc, "_parent_node_index", 0);

  std::vector<int64_t> data2_shape = {3, 3};
  auto data2 = builder.AddNode("data2" + suffix, "Data", 1, 1, FORMAT_NCHW, DT_INT32, data2_shape);
  auto data2_desc = data2->GetOpDesc();
  EXPECT_NE(data2_desc, nullptr);
  AttrUtils::SetInt(data2_desc, "_parent_node_index", 1);

  auto switch1 = builder.AddNode("switch" + suffix, "Switch", 2, 2);
  auto relu1 = builder.AddNode("relu1" + suffix, "Relu", 1, 1);
  auto relu2 = builder.AddNode("relu2" + suffix, "Relu", 1, 1);
  auto merge = builder.AddNode("merge" + suffix, "Merge", 2, 1);

  // Subgraph output, bound to parent output 0, with an unknown value range.
  std::vector<int64_t> out_shape = {8, 8};
  auto netoutput = builder.AddNode("output" + suffix, NETOUTPUT, 1, 0, FORMAT_NCHW, DT_INT32, out_shape);
  auto input0_desc = netoutput->GetOpDesc()->MutableInputDesc(0);
  EXPECT_NE(input0_desc, nullptr);
  AttrUtils::SetInt(input0_desc, "_parent_node_index", 0);
  input0_desc->SetValueRange({{1, -1}, {1, -1}});

  builder.AddDataEdge(data1, 0, switch1, 0);
  builder.AddDataEdge(data2, 0, switch1, 1);
  builder.AddDataEdge(switch1, 0, relu1, 0);
  builder.AddDataEdge(switch1, 1, relu2, 0);
  builder.AddDataEdge(relu1, 0, merge, 0);
  builder.AddDataEdge(relu2, 0, merge, 1);
  builder.AddDataEdge(merge, 0, netoutput, 0);
  return builder;
}
// Attaches `branch_num` Switch-based branch subgraphs to the "case1" node of
// `parent_graph`, registering each instance name on the Case op desc and
// wiring the parent/child links on both sides.
void AddCaseSubgraph(ComputeGraphPtr &parent_graph, uint32_t branch_num) {
  auto case_node = parent_graph->FindNode("case1");
  EXPECT_NE(case_node, nullptr);
  for (uint32_t idx = 0; idx < branch_num; ++idx) {
    const std::string branch_name = "Branch_Graph_" + std::to_string(idx);
    auto branch_builder = SwitchSubgraphBuilder(branch_name, idx);
    auto branch_graph = branch_builder.GetGraph();
    case_node->GetOpDesc()->AddSubgraphName(branch_graph->GetName());
    case_node->GetOpDesc()->SetSubgraphInstanceName(idx, branch_graph->GetName());
    branch_graph->SetParentNode(case_node);
    branch_graph->SetParentGraph(parent_graph);
    EXPECT_EQ(parent_graph->AddSubgraph(branch_graph->GetName(), branch_graph), GRAPH_SUCCESS);
  }
}
// A node type with no registered value-range inference func must be skipped
// gracefully: Run() still returns SUCCESS.
TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UnregisteredNodeType) {
  auto graph = std::make_shared<ComputeGraph>("test_graph");
  GeTensorDesc ge_tensor_desc(GeShape({1, 1, 4, 192}), ge::FORMAT_NCHW, DT_FLOAT16);
  auto addn_op_desc = std::make_shared<OpDesc>("AddN", "AddN");
  addn_op_desc->AddInputDesc(ge_tensor_desc);
  addn_op_desc->AddOutputDesc(ge_tensor_desc);
  auto addn_op_node = graph->AddNode(addn_op_desc);
  InferValueRangePass infer_pass;
  EXPECT_EQ(infer_pass.Run(addn_op_node), SUCCESS);
}
auto ShapeValueInfer = [&](Operator &op) { | |||||
auto op_desc = OpDescUtils::GetOpDescFromOperator(op); | |||||
auto output_tensor_desc = op_desc->MutableOutputDesc(0); | |||||
std::vector<std::pair<int64_t, int64_t>> in_shape_range; | |||||
op_desc->MutableInputDesc(0)->GetShapeRange(in_shape_range); | |||||
if (!in_shape_range.empty()) { | |||||
output_tensor_desc->SetValueRange(in_shape_range); | |||||
} | |||||
return SUCCESS; | |||||
}; | |||||
// Register a minimal "Shape" operator (the tests build its op descs by hand)
// and an IMPL_INFER_VALUE_RANGE_FUNC that copies the input's shape range to
// the output's value range (same logic as ShapeValueInfer above, but in the
// registration-macro form used by INFER_VALUE_RANGE_CUSTOM_FUNC_REG).
REG_OP(Shape)
    .OP_END_FACTORY_REG(Shape)
IMPL_INFER_VALUE_RANGE_FUNC(Shape, ShapeValueRangeFunc){
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  auto output_tensor_desc = op_desc->MutableOutputDesc(0);
  std::vector<std::pair<int64_t, int64_t>> in_shape_range;
  op_desc->MutableInputDesc(0)->GetShapeRange(in_shape_range);
  if (!in_shape_range.empty()) {
    output_tensor_desc->SetValueRange(in_shape_range);
  }
  return GRAPH_SUCCESS;
}
// The custom func is registered for INPUT_IS_DYNAMIC only; with a fully
// static input shape the pass must not produce an output value range.
TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseRegistedFunc_NotInfer) {
  INFER_VALUE_RANGE_CUSTOM_FUNC_REG(Shape, INPUT_IS_DYNAMIC, ShapeValueRangeFunc);
  auto graph = std::make_shared<ComputeGraph>("test_graph");
  GeTensorDesc ge_tensor_desc(GeShape({1, 1, 4, 192}), ge::FORMAT_NCHW, DT_INT32);
  std::vector<std::pair<int64_t, int64_t>> shape_range = {make_pair(1, 1), make_pair(1, 1),
                                                          make_pair(4, 4), make_pair(192, 192)};
  ge_tensor_desc.SetShapeRange(shape_range);
  GeTensorDesc output_tensor_desc(GeShape({4}), ge::FORMAT_NCHW, DT_INT32);
  auto op_desc = std::make_shared<OpDesc>("Shape", "Shape");
  op_desc->AddInputDesc(ge_tensor_desc);
  op_desc->AddOutputDesc(output_tensor_desc);
  auto op_node = graph->AddNode(op_desc);
  InferValueRangePass infer_pass;
  EXPECT_EQ(infer_pass.Run(op_node), SUCCESS);
  // Static input shape => the dynamic-only func must not have run.
  auto output_0_desc = op_node->GetOpDesc()->GetOutputDesc(0);
  std::vector<std::pair<int64_t, int64_t>> value_range;
  output_0_desc.GetValueRange(value_range);
  EXPECT_EQ(value_range.empty(), true);
}
// With a dynamic input shape the registered custom func runs: the Shape
// node's output value range equals the input's shape range, and the range is
// also propagated to the downstream node's input desc.
TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseRegistedFunc_DoInfer) {
  // sqrt -> shape -> Output
  INFER_VALUE_RANGE_CUSTOM_FUNC_REG(Shape, INPUT_IS_DYNAMIC, ShapeValueRangeFunc);
  auto graph = std::make_shared<ComputeGraph>("test_graph");
  GeTensorDesc sqrt_tensor_desc(GeShape({-1, -1, 4, 192}), ge::FORMAT_NCHW, DT_INT32);
  std::vector<std::pair<int64_t, int64_t>> shape_range = {make_pair(1, 100), make_pair(1, 240),
                                                          make_pair(4, 4), make_pair(192, 192)};
  sqrt_tensor_desc.SetShapeRange(shape_range);
  auto sqrt_op_desc = std::make_shared<OpDesc>("Sqrt", "Sqrt");
  sqrt_op_desc->AddInputDesc(sqrt_tensor_desc);
  sqrt_op_desc->AddOutputDesc(sqrt_tensor_desc);
  auto sqrt_node = graph->AddNode(sqrt_op_desc);
  GeTensorDesc shape_output_desc(GeShape({4}), ge::FORMAT_NCHW, DT_INT32);
  auto shape_op_desc = std::make_shared<OpDesc>("Shape", "Shape");
  shape_op_desc->AddInputDesc(sqrt_tensor_desc);
  shape_op_desc->AddOutputDesc(shape_output_desc);
  auto shape_node = graph->AddNode(shape_op_desc);
  GeTensorDesc Output_in_tensor_desc(GeShape({4}), ge::FORMAT_NCHW, ge::DT_INT32);
  auto Output_op_desc = std::make_shared<OpDesc>("Output", "Output");
  Output_op_desc->AddInputDesc(Output_in_tensor_desc);
  auto Output_node = graph->AddNode(Output_op_desc);
  ge::GraphUtils::AddEdge(sqrt_node->GetOutDataAnchor(0), shape_node->GetInDataAnchor(0));
  ge::GraphUtils::AddEdge(shape_node->GetOutDataAnchor(0), Output_node->GetInDataAnchor(0));
  EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS);
  InferValueRangePass infer_pass;
  auto ret = infer_pass.Run(shape_node);
  EXPECT_EQ(ret, SUCCESS);
  // The Shape output's value range must equal the input shape range.
  auto output_0_desc = shape_node->GetOpDesc()->GetOutputDesc(0);
  std::vector<std::pair<int64_t, int64_t>> value_range;
  output_0_desc.GetValueRange(value_range);
  EXPECT_EQ(value_range.size(), 4);
  std::vector<int64_t> target_value_range = {1, 100, 1, 240, 4, 4, 192, 192};
  std::vector<int64_t> output_value_range;
  for (auto pair : value_range) {
    output_value_range.push_back(pair.first);
    output_value_range.push_back(pair.second);
  }
  EXPECT_EQ(target_value_range, output_value_range);
  // The same range must have been copied to the peer Output node's input desc.
  auto in_0_desc = Output_node->GetOpDesc()->GetInputDesc(0);
  value_range.clear();
  in_0_desc.GetValueRange(value_range);
  EXPECT_EQ(value_range.size(), 4);
  output_value_range.clear();
  for (auto pair : value_range) {
    output_value_range.push_back(pair.first);
    output_value_range.push_back(pair.second);
  }
  EXPECT_EQ(target_value_range, output_value_range);
}
class AddKernel : public Kernel { | |||||
public: | |||||
Status Compute(const ge::OpDescPtr op_desc_ptr, const std::vector<ge::ConstGeTensorPtr> &input, | |||||
std::vector<ge::GeTensorPtr> &v_output) override { | |||||
if (input[0]->GetTensorDesc().GetDataType() == DT_INT64 || input[0]->GetTensorDesc().GetDataType() == DT_UINT64) { | |||||
vector<int64_t> data_vec; | |||||
auto data_num = input[0]->GetTensorDesc().GetShape().GetShapeSize(); | |||||
auto x1_data = reinterpret_cast<const int64_t *>(input[0]->GetData().data()); | |||||
auto x2_data = reinterpret_cast<const int64_t *>(input[1]->GetData().data()); | |||||
for (size_t i = 0; i < data_num; i++) { | |||||
auto x_index = *(x1_data + i); | |||||
auto y_index = *(x2_data + i); | |||||
data_vec.push_back(x_index + y_index); | |||||
} | |||||
GeTensorPtr const_tensor = std::make_shared<ge::GeTensor>(input[0]->GetTensorDesc(), (uint8_t *)data_vec.data(), | |||||
data_num * sizeof(int64_t)); | |||||
v_output.emplace_back(const_tensor); | |||||
return SUCCESS; | |||||
} else if (input[0]->GetTensorDesc().GetDataType() == DT_INT32 || input[0]->GetTensorDesc().GetDataType() == DT_UINT32) { | |||||
vector<int32_t> data_vec; | |||||
auto data_num = input[0]->GetTensorDesc().GetShape().GetShapeSize(); | |||||
auto x1_data = reinterpret_cast<const int32_t *>(input[0]->GetData().data()); | |||||
auto x2_data = reinterpret_cast<const int32_t *>(input[1]->GetData().data()); | |||||
for (size_t i = 0; i < data_num; i++) { | |||||
auto x_index = *(x1_data + i); | |||||
auto y_index = *(x2_data + i); | |||||
data_vec.push_back(x_index + y_index); | |||||
} | |||||
GeTensorPtr const_tensor = std::make_shared<ge::GeTensor>(input[0]->GetTensorDesc(), (uint8_t *)data_vec.data(), | |||||
data_num * sizeof(int32_t)); | |||||
v_output.emplace_back(const_tensor); | |||||
return SUCCESS; | |||||
} | |||||
} | |||||
}; | |||||
// Register the folding kernel for Add, and enable the default (CPU-kernel
// based) value-range inference path for Add and Sqrt.
REGISTER_KERNEL(ADD, AddKernel);
INFER_VALUE_RANGE_DEFAULT_REG(Add);
INFER_VALUE_RANGE_DEFAULT_REG(Sqrt);
// When an input's value range contains an unknown bound (-1), the CPU-kernel
// path cannot compute exact values: every output dim's range collapses to
// the unknown range {1, -1}.
TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsHaveUnKnownValueRange) {
  // shape --- add --- sqrt
  // constant /
  auto graph = std::make_shared<ComputeGraph>("test_graph");
  vector<int64_t> dims_vec = {4};
  vector<int64_t> data_vec = {1, 1, 1, 1};
  GeTensorDesc const_tensor_desc(ge::GeShape(dims_vec), ge::FORMAT_NCHW, ge::DT_INT64);
  GeTensorPtr const_tensor =
      std::make_shared<ge::GeTensor>(const_tensor_desc, (uint8_t *)data_vec.data(), data_vec.size() * sizeof(int64_t));
  auto const_op_desc = std::make_shared<OpDesc>("Constant", "Constant");
  const_op_desc->AddOutputDesc(const_tensor_desc);
  EXPECT_EQ(OpDescUtils::SetWeights(const_op_desc, const_tensor), GRAPH_SUCCESS);
  auto const_node = graph->AddNode(const_op_desc);
  GeTensorDesc shape_tensor_desc(GeShape({4}), ge::FORMAT_NCHW, ge::DT_INT64);
  // Second bound of dim 0 is -1, i.e. unknown.
  std::vector<std::pair<int64_t, int64_t>> unknown_value_range = {make_pair(1, -1), make_pair(1, 240),
                                                                  make_pair(4, 4), make_pair(192, 192)};
  shape_tensor_desc.SetValueRange(unknown_value_range);
  auto shape_op_desc = std::make_shared<OpDesc>("Shape", "Shape");
  shape_op_desc->AddOutputDesc(shape_tensor_desc);
  auto shape_node = graph->AddNode(shape_op_desc);
  GeTensorDesc add_tensor_desc(GeShape({4}), ge::FORMAT_NCHW, ge::DT_INT64);
  auto add_op_desc = std::make_shared<OpDesc>("Add", "Add");
  add_op_desc->AddInputDesc(shape_tensor_desc);
  add_op_desc->AddInputDesc(const_tensor_desc);
  add_op_desc->AddOutputDesc(add_tensor_desc);
  auto add_node = graph->AddNode(add_op_desc);
  ge::GraphUtils::AddEdge(shape_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(0));
  ge::GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(1));
  // test unknown value range
  InferValueRangePass infer_pass;
  EXPECT_EQ(infer_pass.Run(add_node), SUCCESS);
  auto output_0_desc = add_node->GetOpDesc()->GetOutputDesc(0);
  std::vector<std::pair<int64_t, int64_t>> out_value_range;
  output_0_desc.GetValueRange(out_value_range);
  EXPECT_EQ(out_value_range.size(), 4);
  std::vector<int64_t> unknown_target_value_range = {1, -1, 1, -1, 1, -1, 1, -1};
  std::vector<int64_t> output_value_range;
  for (auto pair : out_value_range) {
    output_value_range.push_back(pair.first);
    output_value_range.push_back(pair.second);
  }
  EXPECT_EQ(unknown_target_value_range, output_value_range);
}
// With fully known int64 ranges, the CPU kernel computes exact per-dim
// bounds: output range = input range + constant {1,1,1,1} element-wise.
TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsAreKnownValueRange_Int64) {
  // shape --- add --- sqrt
  // constant /
  auto graph = std::make_shared<ComputeGraph>("test_graph");
  vector<int64_t> dims_vec = {4};
  vector<int64_t> data_vec = {1, 1, 1, 1};
  GeTensorDesc const_tensor_desc(ge::GeShape(dims_vec), ge::FORMAT_NCHW, ge::DT_INT64);
  GeTensorPtr const_tensor =
      std::make_shared<ge::GeTensor>(const_tensor_desc, (uint8_t *)data_vec.data(), data_vec.size() * sizeof(int64_t));
  auto const_op_desc = std::make_shared<OpDesc>("Constant", "Constant");
  const_op_desc->AddOutputDesc(const_tensor_desc);
  EXPECT_EQ(OpDescUtils::SetWeights(const_op_desc, const_tensor), GRAPH_SUCCESS);
  auto const_node = graph->AddNode(const_op_desc);
  GeTensorDesc shape_tensor_desc(GeShape({4}), ge::FORMAT_NCHW, ge::DT_INT64);
  std::vector<std::pair<int64_t, int64_t>> unknown_value_range = {make_pair(1, 100), make_pair(1, 240),
                                                                  make_pair(4, 4), make_pair(192, 192)};
  shape_tensor_desc.SetValueRange(unknown_value_range);
  auto shape_op_desc = std::make_shared<OpDesc>("Shape", "Shape");
  shape_op_desc->AddOutputDesc(shape_tensor_desc);
  auto shape_node = graph->AddNode(shape_op_desc);
  GeTensorDesc add_tensor_desc(GeShape({4}), ge::FORMAT_NCHW, ge::DT_INT64);
  auto add_op_desc = std::make_shared<OpDesc>("Add", "Add");
  add_op_desc->AddInputDesc(shape_tensor_desc);
  add_op_desc->AddInputDesc(const_tensor_desc);
  add_op_desc->AddOutputDesc(add_tensor_desc);
  auto add_node = graph->AddNode(add_op_desc);
  auto sqrt_op_desc = std::make_shared<OpDesc>("Sqrt", "Sqrt");
  sqrt_op_desc->AddInputDesc(GeTensorDesc());
  auto sqrt_node = graph->AddNode(sqrt_op_desc);
  ge::GraphUtils::AddEdge(shape_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(0));
  ge::GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(1));
  // NOTE(review): sqrt has a single input desc (anchor 0); GetInDataAnchor(1)
  // is presumably null, so this AddEdge likely fails silently and the edge is
  // never created — confirm whether anchor 0 was intended.
  ge::GraphUtils::AddEdge(add_node->GetOutDataAnchor(0), sqrt_node->GetInDataAnchor(1));
  InferValueRangePass infer_pass;
  EXPECT_EQ(infer_pass.Run(sqrt_node), SUCCESS);
  // test known value range
  EXPECT_EQ(infer_pass.Run(add_node), SUCCESS);
  auto output_0_desc = add_node->GetOpDesc()->GetOutputDesc(0);
  std::vector<std::pair<int64_t, int64_t>> out_value_range;
  output_0_desc.GetValueRange(out_value_range);
  EXPECT_EQ(out_value_range.size(), 4);
  std::vector<int64_t> target_value_range = {2, 101, 2, 241, 5, 5, 193, 193};
  std::vector<int64_t> output_value_range;
  for (auto pair : out_value_range) {
    output_value_range.push_back(pair.first);
    output_value_range.push_back(pair.second);
  }
  EXPECT_EQ(target_value_range, output_value_range);
}
// Int32 variant of the known-range CPU-kernel test: output range = input
// range + constant {1, 100, 2, 200} element-wise.
TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsAreKnownValueRange_Int32) {
  // shape --- add --- sqrt
  // constant /
  auto graph = std::make_shared<ComputeGraph>("test_graph");
  vector<int32_t> data_vec = {1, 100, 2, 200};
  GeTensorDesc const_tensor_desc(ge::GeShape({4}), ge::FORMAT_NCHW, ge::DT_INT32);
  GeTensorPtr const_tensor =
      std::make_shared<ge::GeTensor>(const_tensor_desc, (uint8_t *)data_vec.data(), data_vec.size() * sizeof(int32_t));
  auto const_op_desc = std::make_shared<OpDesc>("Constant", "Constant");
  const_op_desc->AddOutputDesc(const_tensor_desc);
  EXPECT_EQ(OpDescUtils::SetWeights(const_op_desc, const_tensor), GRAPH_SUCCESS);
  auto const_node = graph->AddNode(const_op_desc);
  GeTensorDesc shape_tensor_desc(GeShape({4}), ge::FORMAT_NCHW, ge::DT_INT32);
  std::vector<std::pair<int64_t, int64_t>> known_value_range = {make_pair(1, 100), make_pair(1, 240),
                                                                make_pair(4, 4), make_pair(192, 192)};
  shape_tensor_desc.SetValueRange(known_value_range);
  auto shape_op_desc = std::make_shared<OpDesc>("Shape", "Shape");
  shape_op_desc->AddOutputDesc(shape_tensor_desc);
  auto shape_node = graph->AddNode(shape_op_desc);
  GeTensorDesc add_tensor_desc(GeShape({4}), ge::FORMAT_NCHW, ge::DT_INT32);
  auto add_op_desc = std::make_shared<OpDesc>("Add", "Add");
  add_op_desc->AddInputDesc(shape_tensor_desc);
  add_op_desc->AddInputDesc(const_tensor_desc);
  add_op_desc->AddOutputDesc(add_tensor_desc);
  auto add_node = graph->AddNode(add_op_desc);
  ge::GraphUtils::AddEdge(shape_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(0));
  ge::GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(1));
  InferValueRangePass infer_pass;
  EXPECT_EQ(infer_pass.Run(add_node), SUCCESS);
  auto output_0_desc = add_node->GetOpDesc()->GetOutputDesc(0);
  std::vector<std::pair<int64_t, int64_t>> out_value_range;
  output_0_desc.GetValueRange(out_value_range);
  EXPECT_EQ(out_value_range.size(), 4);
  std::vector<int64_t> target_value_range = {2, 101, 101, 340, 6, 6, 392, 392};
  std::vector<int64_t> output_value_range;
  for (auto pair : out_value_range) {
    output_value_range.push_back(pair.first);
    output_value_range.push_back(pair.second);
  }
  EXPECT_EQ(target_value_range, output_value_range);
}
// Register a minimal "Case" op and a custom value-range inference function for
// it, so the tests below can exercise the custom-func path of the pass.
REG_OP(Case)
    .OP_END_FACTORY_REG(Case)
IMPL_INFER_VALUE_RANGE_FUNC(Case, ValueRangeFunc){
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  auto output_tensor_desc = op_desc->MutableOutputDesc(0);
  // Only set a range if output 0 does not already carry one, so repeated runs
  // of the pass keep the first result (idempotent).
  std::vector<std::pair<int64_t, int64_t>> existing_value_range;
  output_tensor_desc->GetValueRange(existing_value_range);
  if (existing_value_range.empty()) {
    std::vector<std::pair<int64_t, int64_t>> out_value_range = {make_pair(1, 2), make_pair(1, 3),
                                                                make_pair(1, 4), make_pair(1, 5)};
    output_tensor_desc->SetValueRange(out_value_range);
  }
  return GRAPH_SUCCESS;
}
// The custom func is only invoked when the inputs have value ranges.
INFER_VALUE_RANGE_CUSTOM_FUNC_REG(Case, INPUT_HAS_VALUE_RANGE, ValueRangeFunc);
TEST_F(UtestGraphInferValueRangePass, CallRun_HasCaeSubgraph_WhenBeforeSubgraph) {
  auto builder = ParentGraphBuilder();
  auto parent_graph = builder.GetGraph();
  AddCaseSubgraph(parent_graph, 2);
  auto subgraphs = parent_graph->GetAllSubgraphs();
  EXPECT_EQ(subgraphs.size(), 2);

  // Run the pass on the Case node BEFORE subgraph optimization.
  auto case_node = parent_graph->FindNode("case1");
  EXPECT_NE(case_node, nullptr);
  InferValueRangePass infer_pass;
  EXPECT_EQ(infer_pass.Run(case_node), SUCCESS);

  // Flattens [(lo,hi), ...] into [lo, hi, lo, hi, ...] for easy comparison.
  auto flatten = [](const std::vector<std::pair<int64_t, int64_t>> &ranges) {
    std::vector<int64_t> flat;
    for (const auto &range : ranges) {
      flat.push_back(range.first);
      flat.push_back(range.second);
    }
    return flat;
  };

  // Case output 0 carries the range produced by the registered custom func.
  auto case_out_0_desc = case_node->GetOpDesc()->MutableOutputDesc(0);
  std::vector<std::pair<int64_t, int64_t>> case_out_range;
  case_out_0_desc->GetValueRange(case_out_range);
  EXPECT_EQ(case_out_range.size(), 4);
  std::vector<int64_t> target_value_range = {1, 2, 1, 3, 1, 4, 1, 5};
  EXPECT_EQ(target_value_range, flatten(case_out_range));

  // The parent's input ranges are propagated onto the subgraph Data nodes.
  auto data_node = subgraphs[0]->FindNode("data1_0");
  auto data1_out_desc = data_node->GetOpDesc()->GetOutputDesc(0);
  std::vector<std::pair<int64_t, int64_t>> data1_range;
  data1_out_desc.GetValueRange(data1_range);
  EXPECT_EQ(data1_range.size(), 4);
  std::vector<int64_t> data1_target = {1, 1, 1, 1, 1, -1, 1, 224};
  EXPECT_EQ(flatten(data1_range), data1_target);

  data_node = subgraphs[0]->FindNode("data2_0");
  auto data2_in_desc = data_node->GetOpDesc()->GetInputDesc(0);
  std::vector<std::pair<int64_t, int64_t>> data2_range;
  data2_in_desc.GetValueRange(data2_range);
  EXPECT_EQ(data2_range.size(), 2);
  std::vector<int64_t> data2_target = {1, 100, 1, 10};
  EXPECT_EQ(flatten(data2_range), data2_target);
}
TEST_F(UtestGraphInferValueRangePass, CallRun_HasCaeSubgraph_WhenAfterSubgraph) {
  auto builder = ParentGraphBuilder();
  auto parent_graph = builder.GetGraph();
  AddCaseSubgraph(parent_graph, 2);
  auto subgraphs = parent_graph->GetAllSubgraphs();
  EXPECT_EQ(subgraphs.size(), 2);

  auto case_node = parent_graph->FindNode("case1");
  EXPECT_NE(case_node, nullptr);

  // Run the pass AFTER subgraph optimization: the Case output range is merged
  // from the subgraph NetOutput nodes.
  InferValueRangePass infer_pass;
  infer_pass.options_[kOptimizeAfterSubGraph] = "yes";
  EXPECT_EQ(infer_pass.Run(case_node), SUCCESS);

  auto case_out = case_node->GetOpDesc()->GetOutputDescPtr(0);
  std::vector<std::pair<int64_t, int64_t>> out_value_range;
  case_out->GetValueRange(out_value_range);
  EXPECT_EQ(out_value_range.size(), 2);

  std::vector<int64_t> flattened;
  for (const auto &range : out_value_range) {
    flattened.push_back(range.first);
    flattened.push_back(range.second);
  }
  // Merged ranges are unbounded above: [1, -1] per output dimension.
  std::vector<int64_t> out_target_dims = {1, -1, 1, -1};
  EXPECT_EQ(out_target_dims, flattened);
}
TEST_F(UtestGraphInferValueRangePass, CallRun_HasSubgraph_WhenAfterSubgraph_ForMultiDims) {
  auto builder = ParentGraphBuilder();
  auto parent_graph = builder.GetGraph();
  AddCaseSubgraph(parent_graph, 2);
  auto subgraphs = parent_graph->GetAllSubgraphs();
  EXPECT_EQ(subgraphs.size(), 2);

  auto case_node = parent_graph->FindNode("case1");
  EXPECT_NE(case_node, nullptr);

  InferValueRangePass infer_pass;
  infer_pass.options_[kOptimizeAfterSubGraph] = "yes";

  // Multi-batch Case nodes (ATTR_NAME_BATCH_NUM set) are not supported by the
  // after-subgraph path, so the pass is expected to fail.
  auto set_ret = AttrUtils::SetInt(case_node->GetOpDesc(), ATTR_NAME_BATCH_NUM, 2);
  EXPECT_EQ(set_ret, true);
  EXPECT_EQ(infer_pass.Run(case_node), GRAPH_FAILED);
}
} // namespace ge |