| @@ -297,7 +297,9 @@ set(TRAIN_SRC_LIST | |||||
| "graph/passes/hccl_continuous_memcpy_pass.cc" | "graph/passes/hccl_continuous_memcpy_pass.cc" | ||||
| "graph/passes/identity_pass.cc" | "graph/passes/identity_pass.cc" | ||||
| "graph/passes/ref_identity_delete_op_pass.cc" | "graph/passes/ref_identity_delete_op_pass.cc" | ||||
| "graph/passes/infer_base_pass.cc" | |||||
| "graph/passes/infershape_pass.cc" | "graph/passes/infershape_pass.cc" | ||||
| "graph/passes/infer_value_range_pass.cc" | |||||
| "graph/passes/iterator_op_pass.cc" | "graph/passes/iterator_op_pass.cc" | ||||
| "graph/passes/link_gen_mask_nodes_pass.cc" | "graph/passes/link_gen_mask_nodes_pass.cc" | ||||
| "graph/passes/merge_pass.cc" | "graph/passes/merge_pass.cc" | ||||
| @@ -546,7 +548,9 @@ set(INFER_SRC_LIST | |||||
| "graph/passes/shape_operate_op_remove_pass.cc" | "graph/passes/shape_operate_op_remove_pass.cc" | ||||
| "graph/passes/assert_pass.cc" | "graph/passes/assert_pass.cc" | ||||
| "graph/passes/dropout_pass.cc" | "graph/passes/dropout_pass.cc" | ||||
| "graph/passes/infer_base_pass.cc" | |||||
| "graph/passes/infershape_pass.cc" | "graph/passes/infershape_pass.cc" | ||||
| "graph/passes/infer_value_range_pass.cc" | |||||
| "graph/passes/unused_const_pass.cc" | "graph/passes/unused_const_pass.cc" | ||||
| "graph/passes/permute_pass.cc" | "graph/passes/permute_pass.cc" | ||||
| "graph/passes/ctrl_edge_transfer_pass.cc" | "graph/passes/ctrl_edge_transfer_pass.cc" | ||||
| @@ -20,35 +20,9 @@ | |||||
| #include "graph/operator_factory.h" | #include "graph/operator_factory.h" | ||||
| #include "graph/utils/node_utils.h" | #include "graph/utils/node_utils.h" | ||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| #include "init/gelib.h" | |||||
| namespace ge { | namespace ge { | ||||
| const int64_t kStartCallNum = 1; | const int64_t kStartCallNum = 1; | ||||
| const std::string kKernelLibName = "aicpu_tf_kernel"; | |||||
| // tf_kernel.json opsFlag config | |||||
| const std::string kOpsFlagClose = "0"; | |||||
| Status RunOpKernelWithCheck(NodePtr &node, | |||||
| const vector<ConstGeTensorPtr> &inputs, | |||||
| std::vector<GeTensorPtr> &outputs) { | |||||
| std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
| if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) { | |||||
| GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Check][Param] GE is not initialized or is finalized."); | |||||
| return UNSUPPORTED; | |||||
| } | |||||
| OpsKernelInfoStorePtr kernel_info = instance_ptr->OpsKernelManagerObj().GetOpsKernelInfoStore(kKernelLibName); | |||||
| if (kernel_info == nullptr) { | |||||
| GELOGE(FAILED, "[Get][OpsKernelInfoStore] %s failed", kKernelLibName.c_str()); | |||||
| return UNSUPPORTED; | |||||
| } | |||||
| std::string ops_flag; | |||||
| kernel_info->opsFlagCheck(*node, ops_flag); | |||||
| if (ops_flag == kOpsFlagClose) { | |||||
| return UNSUPPORTED; | |||||
| } | |||||
| return FoldingPass::RunOpKernel(node, inputs, outputs); | |||||
| } | |||||
| const map<string, pair<uint64_t, uint64_t>> &ConstantFoldingPass::GetGeConstantFoldingPerfStatistic() const { | const map<string, pair<uint64_t, uint64_t>> &ConstantFoldingPass::GetGeConstantFoldingPerfStatistic() const { | ||||
| return statistic_of_ge_constant_folding_; | return statistic_of_ge_constant_folding_; | ||||
| @@ -81,7 +55,7 @@ Status ConstantFoldingPass::Run(ge::NodePtr &node) { | |||||
| vector<GeTensorPtr> outputs; | vector<GeTensorPtr> outputs; | ||||
| // Statistic of ge constant folding kernel | // Statistic of ge constant folding kernel | ||||
| uint64_t start_time = GetCurrentTimestamp(); | uint64_t start_time = GetCurrentTimestamp(); | ||||
| auto ret = RunOpKernelWithCheck(node, inputs, outputs); | |||||
| auto ret = FoldingPass::RunOpKernelWithCheck(node, inputs, outputs); | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| auto op_kernel = folding_pass::GetKernelByType(node); | auto op_kernel = folding_pass::GetKernelByType(node); | ||||
| if (op_kernel == nullptr) { | if (op_kernel == nullptr) { | ||||
| @@ -29,7 +29,7 @@ | |||||
| #include "inc/kernel_factory.h" | #include "inc/kernel_factory.h" | ||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "ge_local_engine/engine/host_cpu_engine.h" | #include "ge_local_engine/engine/host_cpu_engine.h" | ||||
| #include "init/gelib.h" | |||||
| namespace ge { | namespace ge { | ||||
| namespace folding_pass { | namespace folding_pass { | ||||
| @@ -59,6 +59,9 @@ bool IsNoNeedConstantFolding(const NodePtr &node) { | |||||
| } // namespace folding_pass | } // namespace folding_pass | ||||
| namespace { | namespace { | ||||
| const std::string kKernelLibName = "aicpu_tf_kernel"; | |||||
| const std::string kOpsFlagClose = "0"; | |||||
| IndexsToAnchors GetIndexAndPeerInDataAnchors(NodePtr &node) { | IndexsToAnchors GetIndexAndPeerInDataAnchors(NodePtr &node) { | ||||
| IndexsToAnchors indexes_to_anchors; | IndexsToAnchors indexes_to_anchors; | ||||
| for (auto &out_anchor : node->GetAllOutDataAnchors()) { | for (auto &out_anchor : node->GetAllOutDataAnchors()) { | ||||
| @@ -129,6 +132,27 @@ Status FoldingPass::RunOpKernel(NodePtr &node, | |||||
| return HostCpuEngine::GetInstance().Run(node, inputs, outputs); | return HostCpuEngine::GetInstance().Run(node, inputs, outputs); | ||||
| } | } | ||||
| Status FoldingPass::RunOpKernelWithCheck(NodePtr &node, const vector<ConstGeTensorPtr> &inputs, | |||||
| std::vector<GeTensorPtr> &outputs) { | |||||
| std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
| if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) { | |||||
| GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Check][Param] GE is not initialized or is finalized."); | |||||
| return UNSUPPORTED; | |||||
| } | |||||
| OpsKernelInfoStorePtr kernel_info = instance_ptr->OpsKernelManagerObj().GetOpsKernelInfoStore(kKernelLibName); | |||||
| if (kernel_info == nullptr) { | |||||
| GELOGE(FAILED, "[Get][OpsKernelInfoStore] %s failed", kKernelLibName.c_str()); | |||||
| return UNSUPPORTED; | |||||
| } | |||||
| std::string ops_flag; | |||||
| kernel_info->opsFlagCheck(*node, ops_flag); | |||||
| if (ops_flag == kOpsFlagClose) { | |||||
| return UNSUPPORTED; | |||||
| } | |||||
| return FoldingPass::RunOpKernel(node, inputs, outputs); | |||||
| } | |||||
| Status FoldingPass::Folding(NodePtr &node, vector<GeTensorPtr> &outputs) { | Status FoldingPass::Folding(NodePtr &node, vector<GeTensorPtr> &outputs) { | ||||
| GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
| GELOGD("begin folding node:%s", node->GetName().c_str()); | GELOGD("begin folding node:%s", node->GetName().c_str()); | ||||
| @@ -36,6 +36,9 @@ using IndexsToAnchors = std::map<int, std::vector<InDataAnchorPtr>>; | |||||
| class FoldingPass : public BaseNodePass { | class FoldingPass : public BaseNodePass { | ||||
| public: | public: | ||||
| static Status RunOpKernel(NodePtr &node, const vector<ConstGeTensorPtr> &inputs, vector<GeTensorPtr> &outputs); | static Status RunOpKernel(NodePtr &node, const vector<ConstGeTensorPtr> &inputs, vector<GeTensorPtr> &outputs); | ||||
| static Status RunOpKernelWithCheck(NodePtr &node, const vector<ConstGeTensorPtr> &inputs, | |||||
| std::vector<GeTensorPtr> &outputs); | |||||
| protected: | protected: | ||||
| Status Folding(NodePtr &node, vector<GeTensorPtr> &outputs); | Status Folding(NodePtr &node, vector<GeTensorPtr> &outputs); | ||||
| private: | private: | ||||
| @@ -0,0 +1,676 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "infer_base_pass.h" | |||||
| #include "common/ge/ge_util.h" | |||||
| #include "common/util/error_manager/error_manager.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/util.h" | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "graph/debug/ge_util.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "graph/utils/node_utils.h" | |||||
| #include "graph/utils/tensor_utils.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| namespace ge { | |||||
| namespace { | |||||
/// Formats a list of dimensions for log output.
///
/// The result is "[" followed by each dimension and a single trailing space, then "]",
/// e.g. {2, 3} -> "[2 3 ]" and {} -> "[]".
///
/// @param dims dimensions to serialize.
/// @return the serialized text.
std::string Serial(const std::vector<int64_t> &dims) {
  std::string serial_string = "[";
  for (const int64_t dim : dims) {
    // Append in two steps so no temporary "<dim> " string is built per element.
    serial_string += std::to_string(dim);
    serial_string += ' ';
  }
  serial_string += ']';
  return serial_string;
}
| void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) { | |||||
| desc_str += "["; | |||||
| std::vector<std::pair<int64_t, int64_t>> shape_range; | |||||
| (void)desc->GetShapeRange(shape_range); | |||||
| for (const auto &pair : shape_range) { | |||||
| desc_str += "{"; | |||||
| desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second); | |||||
| desc_str += "},"; | |||||
| } | |||||
| desc_str += "]"; | |||||
| shape_range.clear(); | |||||
| (void)desc->GetOriginShapeRange(shape_range); | |||||
| for (const auto &pair : shape_range) { | |||||
| desc_str += ",{"; | |||||
| desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second); | |||||
| desc_str += "},"; | |||||
| } | |||||
| } | |||||
| void SerialValueRange(const GeTensorDescPtr &desc, std::string &desc_str) { | |||||
| desc_str += "["; | |||||
| std::vector<std::pair<int64_t, int64_t>> value_range; | |||||
| (void)desc->GetValueRange(value_range); | |||||
| for (const auto &pair : value_range) { | |||||
| desc_str += "{"; | |||||
| desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second); | |||||
| desc_str += "},"; | |||||
| } | |||||
| desc_str += "]"; | |||||
| } | |||||
| graphStatus FindSubgraphDataAndNetoutput(const ComputeGraphPtr &sub_graph, NodePtr &netoutput, const ConstNodePtr &node, | |||||
| std::vector<std::vector<GeTensorDesc>> &ref_data_tensors) { | |||||
| auto sub_nodes = sub_graph->GetDirectNode(); | |||||
| for (size_t i = sub_nodes.size(); i > 0; --i) { | |||||
| auto sub_node = sub_nodes.at(i - 1); | |||||
| if (sub_node->GetType() == NETOUTPUT) { | |||||
| netoutput = sub_node; | |||||
| } | |||||
| if (sub_node->GetType() == DATA) { | |||||
| if (sub_node->GetOpDesc() == nullptr) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| int ref_i; | |||||
| if (!AttrUtils::GetInt(sub_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { | |||||
| REPORT_INNER_ERROR("E19999", "subgraph data node[%s] has no parent node!", sub_node->GetName().c_str()); | |||||
| GELOGE(GRAPH_FAILED, "[Get][Int] subgraph data node[%s] has no parent node!", sub_node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllInDataAnchorsSize()) { | |||||
| REPORT_INNER_ERROR("E19999", "data node[%s]'s ref index[%d] is not in range [0, %u)!", | |||||
| sub_node->GetName().c_str(), ref_i, node->GetAllInDataAnchorsSize()); | |||||
| GELOGE(GRAPH_FAILED, "[Check][Param] data node[%s]'s ref index[%d] is not in range [0, %u)!", | |||||
| sub_node->GetName().c_str(), ref_i, node->GetAllInDataAnchorsSize()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| ref_data_tensors[ref_i].emplace_back(sub_node->GetOpDesc()->GetOutputDesc(0)); | |||||
| } | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| } // namespace | |||||
| Status InferBasePass::Run(NodePtr &node) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
| bool need_infer = NeedInfer(node); | |||||
| if (!need_infer) { | |||||
| GELOGD("Node %s does not need to infer.", node->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| std::set<NodePtr> changed_nodes; | |||||
| auto ret = InferAndUpdate(node, !OptionExists(kOptimizeAfterSubGraph), changed_nodes); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| (void)AnalyzeFailedInfo(node); | |||||
| return GE_GRAPH_INFERSHAPE_FAILED; | |||||
| } | |||||
| /* | |||||
| * we will use changed nodes to do repass for control_ops. | |||||
| * AddChangedNodesImmediateRepass(changed_nodes); | |||||
| */ | |||||
| auto status = DoRepassForLoopNode(node); | |||||
| if (status != SUCCESS) { | |||||
| GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "repass failed. node: %s", node->GetName().c_str()); | |||||
| return GE_GRAPH_INFERSHAPE_FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| bool InferBasePass::NeedInfer(const NodePtr &node) { return true; } | |||||
| void InferBasePass::AnalyzeFailedInfo(const NodePtr &node) { /* Analyze and select failed info*/ } | |||||
| Status InferBasePass::DoRepassForLoopNode(NodePtr &node) { return SUCCESS; } | |||||
| graphStatus InferBasePass::UpdatePeerInputs(NodePtr &node) { return GRAPH_SUCCESS; } | |||||
| void InferBasePass::AddChangedNodesImmediateRepass(std::set<NodePtr> &changed_nodes) { | |||||
| for (const auto &node_ele : changed_nodes) { | |||||
| AddImmediateRePassNode(node_ele); | |||||
| } | |||||
| } | |||||
| graphStatus InferBasePass::InferAndUpdate(NodePtr &node, bool before_subgraph, std::set<NodePtr> &changed_nodes) { | |||||
| auto ret = GRAPH_SUCCESS; | |||||
| bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag(); | |||||
| auto opdesc = node->GetOpDesc(); | |||||
| // some op can not infershape twice such as aipp | |||||
| bool need_update_input = !is_unknown_graph && !opdesc->HasAttr("has_infered_verified"); | |||||
| if (need_update_input) { | |||||
| ret = UpdateCurOpInputDesc(node); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", "update op input_desc failed! ret:%d, node:%s", ret, node->GetName().c_str()); | |||||
| GELOGE(GRAPH_FAILED, "[Update][OpInputDesc] failed! ret:%d", ret); | |||||
| return ret; | |||||
| } | |||||
| } | |||||
| bool contain_subgraph = ContainsSubgraph(node); | |||||
| if (contain_subgraph && before_subgraph) { | |||||
| ret = UpdateTensorDescToSubgraphData(node, changed_nodes); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| return ret; | |||||
| } | |||||
| } | |||||
| ret = Infer(node); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| return ret; | |||||
| } | |||||
| if (contain_subgraph && !before_subgraph) { | |||||
| ret = UpdateTensorDescToParentNode(node, changed_nodes); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| return ret; | |||||
| } | |||||
| } | |||||
| ret = UpdatePeerInputs(node); | |||||
| return ret; | |||||
| } | |||||
| graphStatus InferBasePass::UpdateCurOpInputDesc(const NodePtr &node_ptr) { | |||||
| for (const auto &in_anchor : node_ptr->GetAllInDataAnchors()) { | |||||
| auto in_idx = in_anchor->GetIdx(); | |||||
| auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor(); | |||||
| if (peer_out_data_anchor == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode(); | |||||
| if (peer_out_data_node == nullptr || peer_out_data_node->GetOpDesc() == nullptr) { | |||||
| continue; | |||||
| } | |||||
| int peer_out_idx = peer_out_data_anchor->GetIdx(); | |||||
| auto peer_out_desc = peer_out_data_node->GetOpDesc()->MutableOutputDesc(static_cast<uint32_t>(peer_out_idx)); | |||||
| // check shape and dtype continuity. do not stop process | |||||
| auto in_desc = node_ptr->GetOpDesc()->MutableInputDesc(static_cast<uint32_t>(in_idx)); | |||||
| if (in_desc == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto in_shape = in_desc->MutableShape().GetDims(); | |||||
| auto in_dtype = in_desc->GetDataType(); | |||||
| auto peer_out_shape = peer_out_desc->MutableShape().GetDims(); | |||||
| auto peer_out_dtype = peer_out_desc->GetDataType(); | |||||
| if (peer_out_dtype != in_dtype) { | |||||
| GELOGW( | |||||
| "current node [%s] [%d]\'th in_dtype is [%s].peer output node [%s] [%d]\'th " | |||||
| "output_dtype is [%s].The two dtype should be same! Please check graph and fix it", | |||||
| node_ptr->GetName().c_str(), in_idx, TypeUtils::DataTypeToSerialString(in_dtype).c_str(), | |||||
| peer_out_data_node->GetName().c_str(), peer_out_idx, TypeUtils::DataTypeToSerialString(peer_out_dtype).c_str()); | |||||
| } else if ((!in_shape.empty()) && (in_shape != peer_out_shape)) { | |||||
| string in_shape_str = Serial(in_shape); | |||||
| string peer_out_shape_str = Serial(peer_out_shape); | |||||
| GELOGW( | |||||
| "current node [%s] [%d]\'th in_shape is [%s].peer output node [%s] [%d]\'th " | |||||
| "output_shape is [%s].The two shape should be same! Please check graph and fix it", | |||||
| node_ptr->GetName().c_str(), in_idx, in_shape_str.c_str(), peer_out_data_node->GetName().c_str(), peer_out_idx, | |||||
| peer_out_shape_str.c_str()); | |||||
| } | |||||
| // refresh current node input desc | |||||
| bool output_changed = false; | |||||
| (void)UpdateInputDescAttr(peer_out_desc, in_desc, output_changed); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus InferBasePass::UpdateInputDescAttr(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) { | |||||
| changed = false; | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| bool InferBasePass::ContainsSubgraph(const NodePtr &node) { | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); | |||||
| if (sub_graph_names.empty()) { | |||||
| return false; | |||||
| } | |||||
| auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); | |||||
| if (root_graph == nullptr) { | |||||
| return false; | |||||
| } | |||||
| for (const auto &name : sub_graph_names) { | |||||
| if (name.empty()) { | |||||
| continue; | |||||
| } | |||||
| auto sub_graph = root_graph->GetSubgraph(name); | |||||
| if (sub_graph != nullptr) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| std::vector<ComputeGraphPtr> InferBasePass::GetCurNodeSubgraphs(const NodePtr &node) { | |||||
| std::vector<ComputeGraphPtr> cur_node_subgraph; | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); | |||||
| if (sub_graph_names.empty()) { | |||||
| return cur_node_subgraph; | |||||
| } | |||||
| auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); | |||||
| for (const auto &name : sub_graph_names) { | |||||
| if (name.empty()) { | |||||
| GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| auto sub_graph = root_graph->GetSubgraph(name); | |||||
| if (sub_graph == nullptr) { | |||||
| REPORT_INNER_ERROR("E19999", "Can not find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); | |||||
| GE_LOGE("[Get][Graph] can not find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| cur_node_subgraph.emplace_back(sub_graph); | |||||
| } | |||||
| return cur_node_subgraph; | |||||
| } | |||||
| graphStatus InferBasePass::UpdateTensorDescToSubgraphData(NodePtr &node, std::set<NodePtr> &changed_nodes) { | |||||
| // if infer again, update output of while into subgraph data node | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| for (const auto &sub_graph : GetCurNodeSubgraphs(node)) { | |||||
| for (const auto &node_sub : sub_graph->GetDirectNode()) { | |||||
| if (node_sub->GetType() != DATA) { | |||||
| continue; | |||||
| } | |||||
| auto name = sub_graph->GetName(); | |||||
| int ref_i; | |||||
| auto data_opdesc = node_sub->GetOpDesc(); | |||||
| if (data_opdesc == nullptr) { | |||||
| REPORT_INNER_ERROR("E19999", "Invalid data node on the sub graph %s parent node %s, no OpDesc", name.c_str(), | |||||
| node->GetName().c_str()); | |||||
| GE_LOGE("[Get][OpDesc] Invalid data node on the sub graph %s parent node %s, no OpDesc", name.c_str(), | |||||
| node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| if (!AttrUtils::GetInt(data_opdesc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { | |||||
| REPORT_INNER_ERROR("E19999", "Invalid data node on the sub graph %s parent node %s, no ref-index attribute", | |||||
| name.c_str(), node->GetName().c_str()); | |||||
| GE_LOGE("[Get][Int] Invalid data node on the sub graph %s parent node %s, no ref-index attribute", name.c_str(), | |||||
| node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| if (data_opdesc->HasAttr(ATTR_MBATCH_ORIGIN_INPUT_DIMS)) { | |||||
| continue; | |||||
| } | |||||
| auto input_desc = op_desc->MutableInputDesc(ref_i); | |||||
| if (input_desc == nullptr) { | |||||
| REPORT_INNER_ERROR("E19999", | |||||
| "The ref index(%d) on the data %s on the sub graph %s " | |||||
| "parent node %s are incompatible, inputs num %u", | |||||
| ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), | |||||
| node->GetAllInDataAnchorsSize()); | |||||
| GE_LOGE( | |||||
| "[Call][MutableInputDesc] The ref index(%d) on the data %s on the sub graph %s " | |||||
| "parent node %s are incompatible, inputs num %u", | |||||
| ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), node->GetAllInDataAnchorsSize()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| GELOGI("Ref index is %d, input_desc dtype is %d, node name is %s", ref_i, input_desc->GetDataType(), | |||||
| node->GetName().c_str()); | |||||
| // if need infer again, refresh subgraph input with output | |||||
| bool is_infer_again = false; | |||||
| AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, is_infer_again); | |||||
| if (is_infer_again) { | |||||
| input_desc = op_desc->MutableOutputDesc(ref_i); | |||||
| if (input_desc == nullptr) { | |||||
| REPORT_INNER_ERROR("E19999", | |||||
| "The ref index(%d) on the data %s on the subgraph %s " | |||||
| "parent node %s are incompatible, outputs num %u.", | |||||
| ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), | |||||
| node->GetAllOutDataAnchorsSize()); | |||||
| GELOGE(PARAM_INVALID, | |||||
| "[Call][MutableOutputDesc] The ref index(%d) on the data %s on the subgraph %s " | |||||
| "parent node %s are incompatible, outputs num %u.", | |||||
| ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), | |||||
| node->GetAllOutDataAnchorsSize()); | |||||
| } | |||||
| GELOGD("Update input desc of data %s on the sub graph %s of node %s,output idx: %d from [%s] to [%s]", | |||||
| node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), ref_i, | |||||
| data_opdesc->GetInputDescPtr(0)->GetShape().ToString().c_str(), | |||||
| input_desc->GetShape().ToString().c_str()); | |||||
| } | |||||
| auto data_input_desc = data_opdesc->MutableInputDesc(0); | |||||
| auto ret = data_opdesc->UpdateInputDesc(0, *input_desc); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", "Failed to update input desc of data %s on the sub graph %s parent node %s", | |||||
| node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str()); | |||||
| GE_LOGE("[Update][InputDesc] of data %s on the sub graph %s parent node %s failed", node_sub->GetName().c_str(), | |||||
| name.c_str(), node->GetName().c_str()); | |||||
| return ret; | |||||
| } | |||||
| bool input_changed = TensorDescChanged(input_desc, data_input_desc); | |||||
| auto data_output_desc = data_opdesc->MutableOutputDesc(0); | |||||
| ret = data_opdesc->UpdateOutputDesc(0, *input_desc); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", "Failed to update output desc of data %s on the sub graph %s parent node %s", | |||||
| node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str()); | |||||
| GE_LOGE("[Update][OutputDesc] of data %s on the sub graph %s parent node %s failed", | |||||
| node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str()); | |||||
| return ret; | |||||
| } | |||||
| bool output_changed = TensorDescChanged(input_desc, data_output_desc); | |||||
| if (input_changed || output_changed) { | |||||
| changed_nodes.insert(node_sub); | |||||
| } | |||||
| } | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus InferBasePass::UpdateTensorDescToParentNode(NodePtr &node, std::set<NodePtr> &changed_nodes) { | |||||
| std::vector<std::vector<GeTensorDesc>> ref_data_tensors(node->GetAllInDataAnchorsSize()); | |||||
| std::vector<std::vector<GeTensorDesc>> ref_out_tensors(node->GetAllOutDataAnchorsSize()); | |||||
| for (const auto &sub_graph : GetCurNodeSubgraphs(node)) { | |||||
| auto name = sub_graph->GetName(); | |||||
| NodePtr netoutput = nullptr; | |||||
| auto ret = FindSubgraphDataAndNetoutput(sub_graph, netoutput, node, ref_data_tensors); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| return ret; | |||||
| } | |||||
| if (netoutput == nullptr) { | |||||
| REPORT_INNER_ERROR("E19999", "No NetOutput node on sub graph %s, parent node %s", name.c_str(), | |||||
| node->GetName().c_str()); | |||||
| GE_LOGE("[Check][Param] No NetOutput node on sub graph %s, parent node %s", name.c_str(), | |||||
| node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| auto netoutput_opdesc = netoutput->GetOpDesc(); | |||||
| if (netoutput_opdesc == nullptr) { | |||||
| REPORT_INNER_ERROR("E19999", "Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it", | |||||
| name.c_str(), node->GetName().c_str()); | |||||
| GE_LOGE("[Get][OpDesc] Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it", name.c_str(), | |||||
| node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| for (auto &edge_anchor : netoutput->GetAllInDataAnchors()) { | |||||
| auto edge_desc = netoutput_opdesc->MutableInputDesc(edge_anchor->GetIdx()); | |||||
| if (edge_desc == nullptr) { | |||||
| REPORT_INNER_ERROR("E19999", | |||||
| "Invalid NetOutput node on sub graph %s, parent node %s, " | |||||
| "can not find input tensor %d", | |||||
| name.c_str(), node->GetName().c_str(), edge_anchor->GetIdx()); | |||||
| GE_LOGE("[Get][Tensor] Invalid NetOutput node on sub graph %s, parent node %s, can not find input tensor %d", | |||||
| name.c_str(), node->GetName().c_str(), edge_anchor->GetIdx()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| GELOGI("Netoutput in anchor index is %d, input tensor dim is %zu", edge_anchor->GetIdx(), | |||||
| edge_desc->GetShape().GetDimNum()); | |||||
| int ref_i; | |||||
| if (!AttrUtils::GetInt(edge_desc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { | |||||
| // if there is no ref index on the TensorDesc, it means the output data will be ignored outer. | |||||
| continue; | |||||
| } | |||||
| GELOGI("Parent node index of edge desc is %d", ref_i); | |||||
| if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllOutDataAnchorsSize()) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| ref_out_tensors[ref_i].emplace_back(*edge_desc); | |||||
| } | |||||
| } | |||||
| if (node->GetType() == WHILE) { | |||||
| return UpdateParentNodeForWhile(node, ref_data_tensors, ref_out_tensors, changed_nodes); | |||||
| } | |||||
| return UpdateParentNodeForBranch(node, ref_out_tensors, changed_nodes); | |||||
| } | |||||
| graphStatus InferBasePass::UpdateParentNodeForWhile(NodePtr &node, | |||||
| std::vector<std::vector<GeTensorDesc>> &ref_data_tensors, | |||||
| std::vector<std::vector<GeTensorDesc>> &ref_out_tensors, | |||||
| std::set<NodePtr> &changed_nodes) { | |||||
| GELOGD("Enter update parent node shape for class while op process"); | |||||
| if (ref_data_tensors.size() != ref_out_tensors.size()) { | |||||
| REPORT_INNER_ERROR("E19999", "op:%s(%s) input number[%zu] and output number[%zu] is not same!", | |||||
| node->GetName().c_str(), node->GetType().c_str(), ref_data_tensors.size(), | |||||
| ref_out_tensors.size()); | |||||
| GELOGE(GRAPH_FAILED, "[Check][Param] while op [%s] input number[%zu] and output number[%zu] is not same!", | |||||
| node->GetName().c_str(), ref_data_tensors.size(), ref_out_tensors.size()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| for (size_t i = 0; i < ref_data_tensors.size(); i++) { | |||||
| if (ref_out_tensors[i].size() != 1) { | |||||
| REPORT_INNER_ERROR("E19999", "while op, every output should only find one output tensor in all graph!"); | |||||
| GELOGE(GRAPH_FAILED, "[Check][Param] while op, every output should only find one output tensor in all graph!"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| } | |||||
| bool need_infer_again = false; | |||||
| // check input and output | |||||
| for (size_t i = 0; i < ref_out_tensors.size(); i++) { | |||||
| if (ref_out_tensors[i].empty()) { | |||||
| continue; | |||||
| } | |||||
| auto ref_out_tensor = ref_out_tensors[i].at(0); | |||||
| auto out_shape = ref_out_tensor.MutableShape(); | |||||
| vector<std::pair<int64_t, int64_t>> data_shape_range; | |||||
| // ref_i's data and output tensor shape should be same | |||||
| for (auto &tensor : ref_data_tensors[i]) { | |||||
| if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { | |||||
| REPORT_INNER_ERROR("E19999", "node[%s] does not support diff dtype or format among all ref output", | |||||
| node->GetName().c_str()); | |||||
| GELOGE(GRAPH_FAILED, "[Check][Param] node[%s] does not support diff dtype or format output.", | |||||
| node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| auto data_shape = tensor.MutableShape(); | |||||
| // input is dynamic, here use dim_num | |||||
| if (data_shape.GetDims() != out_shape.GetDims()) { | |||||
| GELOGI("After infer, While %s %zu output shape [%s] is not match with input shape [%s].Need infer again.", | |||||
| node->GetName().c_str(), i, out_shape.ToString().c_str(), data_shape.ToString().c_str()); | |||||
| if (data_shape.GetDimNum() != out_shape.GetDimNum()) { | |||||
| ref_out_tensor.SetUnknownDimNumShape(); | |||||
| } else { | |||||
| for (size_t j = 0; j < data_shape.GetDimNum(); ++j) { | |||||
| if (data_shape.GetDim(j) != out_shape.GetDim(j)) { | |||||
| if (data_shape.GetDim(j) != UNKNOWN_DIM) { | |||||
| // if input data is fix shape, output is different, need_infer_again | |||||
| need_infer_again = true; | |||||
| } | |||||
| data_shape.SetDim(j, UNKNOWN_DIM); | |||||
| } | |||||
| // set shape rang of while, if dim is unknown ,set shape range as {1,-1} | |||||
| if (data_shape.GetDim(j) == UNKNOWN_DIM) { | |||||
| data_shape_range.emplace_back(std::make_pair(1, UNKNOWN_DIM)); | |||||
| } else { | |||||
| data_shape_range.emplace_back(std::make_pair(data_shape.GetDim(j), data_shape.GetDim(j))); | |||||
| } | |||||
| } | |||||
| ref_out_tensor.SetShape(data_shape); | |||||
| ref_out_tensor.SetShapeRange(data_shape_range); | |||||
| } | |||||
| } | |||||
| } | |||||
| auto output_desc = node->GetOpDesc()->MutableOutputDesc(i); | |||||
| (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor); | |||||
| bool output_changed = TensorDescChanged(ComGraphMakeShared<GeTensorDesc>(ref_out_tensor), output_desc); | |||||
| if (output_changed) { | |||||
| changed_nodes.insert(node); | |||||
| } | |||||
| } | |||||
| AttrUtils::SetBool(node->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, need_infer_again); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus InferBasePass::UpdateOutputForMultiBatch(NodePtr &node, | |||||
| std::vector<std::vector<GeTensorDesc>> &ref_out_tensors, | |||||
| std::set<NodePtr> &changed_nodes) { | |||||
| // check sub_graph shape. Get max for update. | |||||
| for (size_t i = 0; i < ref_out_tensors.size(); ++i) { | |||||
| if (ref_out_tensors[i].empty()) { | |||||
| continue; | |||||
| } | |||||
| int64_t max_size = 0; | |||||
| size_t max_shape_index = 0; | |||||
| auto &ref_out_tensor = ref_out_tensors[i].at(0); | |||||
| for (size_t j = 0; j < ref_out_tensors[i].size(); ++j) { | |||||
| auto &tensor = ref_out_tensors[i].at(j); | |||||
| if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { | |||||
| REPORT_INNER_ERROR("E19999", "node[%s] does not support diff dtype among all ref output", | |||||
| node->GetName().c_str()); | |||||
| GELOGE(GRAPH_FAILED, "[Check][Param] node[%s] does not support diff dtype among all ref output", | |||||
| node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| auto shape = tensor.MutableShape(); | |||||
| int64_t size = 1; | |||||
| for (auto dim : shape.GetDims()) { | |||||
| if (dim != 0 && INT64_MAX / dim < size) { | |||||
| REPORT_INNER_ERROR("E19999", "The shape:%s size overflow, node:%s", shape.ToString().c_str(), | |||||
| node->GetName().c_str()); | |||||
| GELOGE(PARAM_INVALID, "[Check][Overflow] The shape size overflow"); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| size *= dim; | |||||
| } | |||||
| if (size > max_size) { | |||||
| max_size = size; | |||||
| max_shape_index = j; | |||||
| } | |||||
| } | |||||
| auto output_desc = node->GetOpDesc()->MutableOutputDesc(i); | |||||
| (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensors[i].at(max_shape_index)); | |||||
| bool output_changed = | |||||
| TensorDescChanged(ComGraphMakeShared<GeTensorDesc>(ref_out_tensors[i].at(max_shape_index)), output_desc); | |||||
| if (output_changed) { | |||||
| changed_nodes.insert(node); | |||||
| } | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus InferBasePass::UpdateParentNodeForBranch(NodePtr &node, | |||||
| std::vector<std::vector<GeTensorDesc>> &ref_out_tensors, | |||||
| std::set<NodePtr> &changed_nodes) { | |||||
| GELOGD("Enter update parent node shape for class branch op process"); | |||||
| if (node->GetOpDesc()->HasAttr(ATTR_NAME_BATCH_NUM)) { | |||||
| return UpdateOutputForMultiBatch(node, ref_out_tensors, changed_nodes); | |||||
| } | |||||
| // check sub_graph shape.If not same ,do unknown shape process | |||||
| for (size_t i = 0; i < ref_out_tensors.size(); i++) { | |||||
| if (ref_out_tensors[i].empty()) { | |||||
| continue; | |||||
| } | |||||
| auto ref_out_tensor = ref_out_tensors[i].at(0); | |||||
| ge::GeShape &ref_out_tensor_shape = ref_out_tensor.MutableShape(); | |||||
| for (auto &tensor : ref_out_tensors[i]) { | |||||
| if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { | |||||
| REPORT_INNER_ERROR("E19999", "node[%s] does not support diff dtype among all ref output, shape:%s", | |||||
| node->GetName().c_str(), ref_out_tensor_shape.ToString().c_str()); | |||||
| GELOGE(GRAPH_FAILED, "[Check][Param] node[%s] does not support diff dtype output", node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| auto shape = tensor.MutableShape(); | |||||
| if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) { | |||||
| GELOGD("node is %s, i : %zu, shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), i, | |||||
| shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); | |||||
| ref_out_tensor_shape = GeShape(UNKNOWN_RANK); | |||||
| break; | |||||
| } | |||||
| for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) { | |||||
| if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) { | |||||
| continue; | |||||
| } | |||||
| GELOGD("node is %s, i : %zu, j: %zu ,shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), | |||||
| i, j, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); | |||||
| (void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM); | |||||
| } | |||||
| } | |||||
| auto output_desc = node->GetOpDesc()->MutableOutputDesc(i); | |||||
| (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor); | |||||
| bool output_changed = | |||||
| TensorDescChanged(ComGraphMakeShared<GeTensorDesc>(ref_out_tensor), output_desc); | |||||
| if (output_changed) { | |||||
| changed_nodes.insert(node); | |||||
| } | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| void InferBasePass::PrintInOutTensorShape(const NodePtr &node, const std::string &phase) { | |||||
| if (!IsLogEnable(GE, DLOG_DEBUG)) { | |||||
| return; | |||||
| } | |||||
| if (node == nullptr) { | |||||
| REPORT_INNER_ERROR("E19999", "param node is nullprt, check invalid"); | |||||
| GELOGE(GRAPH_FAILED, "[Check][Param] node is null"); | |||||
| return; | |||||
| } | |||||
| ge::OpDescPtr op_desc = node->GetOpDesc(); | |||||
| GE_IF_BOOL_EXEC(op_desc == nullptr, REPORT_INNER_ERROR("E19999", "node has no opdesc, check invalid"); | |||||
| GELOGE(GRAPH_FAILED, "[Get][OpDesc] op_desc is null."); return ); | |||||
| std::stringstream ss; | |||||
| ss << "{"; | |||||
| int32_t in_idx = 0; | |||||
| int32_t out_idx = 0; | |||||
| for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { | |||||
| if (input_desc == nullptr) { | |||||
| in_idx++; | |||||
| continue; | |||||
| } | |||||
| if (in_idx > 0) { | |||||
| ss << " "; | |||||
| } | |||||
| ss << "input_" << in_idx << " " | |||||
| << "tensor: ["; | |||||
| ss << "(shape:[" << input_desc->MutableShape().ToString() << "]),"; | |||||
| ss << "(format:" << TypeUtils::FormatToSerialString(input_desc->GetFormat()) << "),"; | |||||
| ss << "(dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) << "),"; | |||||
| ss << "(origin_shape:" << input_desc->GetOriginShape().ToString() << "),"; | |||||
| ss << "(origin_format:" << TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()) << "),"; | |||||
| ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()) << "),"; | |||||
| string range_str; | |||||
| SerialShapeRange(input_desc, range_str); | |||||
| ss << "(shape_range:" << range_str << "),"; | |||||
| string value_range_str; | |||||
| SerialValueRange(input_desc, value_range_str); | |||||
| ss << "(value_range:" << value_range_str << ")]"; | |||||
| in_idx++; | |||||
| } | |||||
| for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) { | |||||
| if (output_desc == nullptr) { | |||||
| out_idx++; | |||||
| continue; | |||||
| } | |||||
| ss << " "; | |||||
| ss << "output_" << out_idx << " " | |||||
| << "tensor: ["; | |||||
| ss << "(shape:[" << output_desc->MutableShape().ToString() << "]),"; | |||||
| ss << "(format:" << TypeUtils::FormatToSerialString(output_desc->GetFormat()) << "),"; | |||||
| ss << "(dtype:" << TypeUtils::DataTypeToSerialString(output_desc->GetDataType()) << "),"; | |||||
| ss << "(origin_shape:" << output_desc->GetOriginShape().ToString() << "),"; | |||||
| ss << "(origin_format:" << TypeUtils::FormatToSerialString(output_desc->GetOriginFormat()) << "),"; | |||||
| ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(output_desc->GetOriginDataType()) << "),"; | |||||
| string range_str; | |||||
| SerialShapeRange(output_desc, range_str); | |||||
| ss << "(shape_range:" << range_str << "),"; | |||||
| string value_range_str; | |||||
| SerialValueRange(output_desc, value_range_str); | |||||
| ss << "(value_range:" << value_range_str << ")]"; | |||||
| out_idx++; | |||||
| } | |||||
| ss << "}"; | |||||
| GELOGD("Shape dump [%s], Node name: [%s]. %s", phase.c_str(), node->GetName().c_str(), ss.str().c_str()); | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,53 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GRAPH_PASSES_INFER_BASE_PASS_H_ | |||||
| #define GE_GRAPH_PASSES_INFER_BASE_PASS_H_ | |||||
| #include "graph/passes/base_pass.h" | |||||
| namespace ge { | |||||
| class InferBasePass : public BaseNodePass { | |||||
| public: | |||||
| Status Run(NodePtr &node) override; | |||||
| graphStatus InferAndUpdate(NodePtr &node, bool before_subgraph, std::set<NodePtr> &changed_nodes); | |||||
| void PrintInOutTensorShape(const NodePtr &node, const std::string &phase); | |||||
| protected: | |||||
| virtual graphStatus Infer(NodePtr &node) = 0; | |||||
| virtual bool TensorDescChanged(const GeTensorDescPtr &src, const GeTensorDescPtr &dst) = 0; | |||||
| virtual graphStatus UpdateInputDescAttr(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed); | |||||
| virtual bool NeedInfer(const NodePtr &node); | |||||
| virtual void AnalyzeFailedInfo(const NodePtr &node); | |||||
| virtual Status DoRepassForLoopNode(NodePtr &node); // only for infershape, will be deleted | |||||
| virtual graphStatus UpdatePeerInputs(NodePtr &node); // only for infershape, will be deleted | |||||
| private: | |||||
| void AddChangedNodesImmediateRepass(std::set<NodePtr> &changed_nodes); | |||||
| graphStatus UpdateCurOpInputDesc(const NodePtr &node_ptr); | |||||
| bool ContainsSubgraph(const NodePtr &node); | |||||
| std::vector<ComputeGraphPtr> GetCurNodeSubgraphs(const NodePtr &node); | |||||
| graphStatus UpdateTensorDescToSubgraphData(NodePtr &node, std::set<NodePtr> &changed_nodes); | |||||
| graphStatus UpdateTensorDescToParentNode(NodePtr &node, std::set<NodePtr> &changed_nodes); | |||||
| graphStatus UpdateParentNodeForWhile(NodePtr &node, std::vector<std::vector<GeTensorDesc>> &ref_data_tensors, | |||||
| std::vector<std::vector<GeTensorDesc>> &ref_out_tensors, | |||||
| std::set<NodePtr> &changed_nodes); | |||||
| graphStatus UpdateParentNodeForBranch(NodePtr &node, std::vector<std::vector<GeTensorDesc>> &ref_out_tensors, | |||||
| std::set<NodePtr> &changed_nodes); | |||||
| graphStatus UpdateOutputForMultiBatch(NodePtr &node, std::vector<std::vector<GeTensorDesc>> &ref_out_tensors, | |||||
| std::set<NodePtr> &changed_nodes); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // GE_GRAPH_PASSES_INFER_BASE_PASS_H_ | |||||
| @@ -0,0 +1,388 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph/passes/infer_value_range_pass.h" | |||||
| #include "common/util/error_manager/error_manager.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "graph/operator_factory_impl.h" | |||||
| #include "graph/passes/folding_pass.h" | |||||
| #include "common/ge/ge_util.h" | |||||
| #include "init/gelib.h" | |||||
| using std::unique_ptr; | |||||
| namespace ge { | |||||
| namespace { | |||||
| #define GET_DATA_BY_DTYPE(DTYPE, TYPE) \ | |||||
| case (DTYPE): \ | |||||
| ConstructValueRange<TYPE>(lower_tensor, higher_tensor, output_tensor_value_range); \ | |||||
| break; | |||||
| Status RunCpuKernelForValueRange(NodePtr &node, const vector<ConstGeTensorPtr> &inputs, | |||||
| std::vector<GeTensorPtr> &outputs) { | |||||
| // should use RunOpKernelWithCheck, RunOpKernel for ut test | |||||
| auto ret = FoldingPass::RunOpKernel(node, inputs, outputs); | |||||
| if (ret != SUCCESS) { | |||||
| auto op_kernel = folding_pass::GetKernelByType(node); | |||||
| if (op_kernel == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "Calculate value range failed, no op kernel for node %s type %s", node->GetName().c_str(), | |||||
| node->GetType().c_str()); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| ret = op_kernel->Compute(node->GetOpDesc(), inputs, outputs); | |||||
| if (ret != SUCCESS) { | |||||
| REPORT_INNER_ERROR("E19999", "Calculate for node %s(%s) failed", node->GetName().c_str(), | |||||
| node->GetType().c_str()); | |||||
| GELOGE(INTERNAL_ERROR, "Calculate for node %s failed in constant folding", node->GetName().c_str()); | |||||
| return ret; | |||||
| } | |||||
| } | |||||
| GELOGI("Node %s type %s, run cpu kernel success.", node->GetName().c_str(), node->GetType().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace | |||||
| graphStatus InferValueRangePass::Infer(NodePtr &node) { | |||||
| PrintInOutTensorShape(node, "before_infer_value_range"); | |||||
| auto infer_value_range_param = OperatorFactoryImpl::GetInferValueRangePara(node->GetType()); | |||||
| // Use registered func to calculate value range | |||||
| if (!infer_value_range_param.use_cpu_kernel) { | |||||
| if (infer_value_range_param.infer_value_func == nullptr) { | |||||
| GELOGE(GRAPH_PARAM_INVALID, "The registered func to infer value range is nullptr."); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| Operator op = OpDescUtils::CreateOperatorFromNode(node); | |||||
| auto ret = node->GetOpDesc()->CallInferValueRangeFunc(op); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", "Node %s call infer value range function failed.", node->GetName().c_str()); | |||||
| GELOGE(GRAPH_FAILED, "[Call][InferFunction] failed, node: %s.", node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| // Use CPU kernel func to calculate value range | |||||
| return ConstructInputAndInferValueRange(node); | |||||
| } | |||||
| bool InferValueRangePass::NeedInfer(const NodePtr &node) { | |||||
| auto infer_value_range_param = OperatorFactoryImpl::GetInferValueRangePara(node->GetType()); | |||||
| if (!infer_value_range_param.is_initialized) { | |||||
| GELOGD("Node %s does not register func to infer value range, skip infer_value_range_pass.", | |||||
| node->GetName().c_str()); | |||||
| return false; | |||||
| } | |||||
| if (infer_value_range_param.when_call == INPUT_IS_DYNAMIC) { | |||||
| // Only do infer for node that all inputs are dynamic, such as shape | |||||
| if (InputIsDynamic(node)) { | |||||
| return true; | |||||
| } | |||||
| GELOGD("Node %s register func to infer value range and when_call is INPUT_IS_DYNAMIC, but check input failed.", | |||||
| node->GetName().c_str()); | |||||
| } else if (infer_value_range_param.when_call == INPUT_HAS_VALUE_RANGE) { | |||||
| // Only do infer for node that all inputs have value_range or node type of inputs is constant/const | |||||
| if (InputIsConstOrHasValueRange(node)) { | |||||
| return true; | |||||
| } | |||||
| GELOGD("Node %s register func to infer value range and when_call is INPUT_HAS_VALUE_RANGE, but check input failed.", | |||||
| node->GetName().c_str()); | |||||
| } | |||||
| GELOGD("Node %s does not need to infer value range, skip infer_value_range_pass.", node->GetName().c_str()); | |||||
| return false; | |||||
| } | |||||
| bool InferValueRangePass::TensorDescChanged(const GeTensorDescPtr &src, const GeTensorDescPtr &dst) { | |||||
| bool changed = false; | |||||
| std::vector<std::pair<int64_t, int64_t>> src_value_range; | |||||
| std::vector<std::pair<int64_t, int64_t>> dst_value_range; | |||||
| (void)src->GetValueRange(src_value_range); | |||||
| (void)dst->GetValueRange(dst_value_range); | |||||
| if (src_value_range != dst_value_range) { | |||||
| changed = true; | |||||
| } | |||||
| return changed; | |||||
| } | |||||
| graphStatus InferValueRangePass::UpdateInputDescAttr(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) { | |||||
| changed = false; | |||||
| std::vector<std::pair<int64_t, int64_t>> src_value_range; | |||||
| std::vector<std::pair<int64_t, int64_t>> dst_value_range; | |||||
| (void)src->GetValueRange(src_value_range); | |||||
| (void)dst->GetValueRange(dst_value_range); | |||||
| if (src_value_range != dst_value_range) { | |||||
| changed = true; | |||||
| } | |||||
| dst->SetValueRange(src_value_range); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| void InferValueRangePass::AnalyzeFailedInfo(const NodePtr &node) { | |||||
| REPORT_CALL_ERROR("E19999", "Infer value range for node:%s(%s) failed.", node->GetName().c_str(), | |||||
| node->GetType().c_str()); | |||||
| GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "infer value range failed. node: %s", node->GetName().c_str()); | |||||
| } | |||||
| bool InferValueRangePass::InputIsDynamic(const NodePtr &node) { | |||||
| bool input_is_dynamic = false; | |||||
| auto cur_op_desc = node->GetOpDesc(); | |||||
| for (const auto &input_desc : cur_op_desc->GetAllInputsDescPtr()) { | |||||
| auto dims = input_desc->GetShape().GetDims(); | |||||
| for (auto dim : dims) { | |||||
| if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) { | |||||
| input_is_dynamic = true; | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| return input_is_dynamic; | |||||
| } | |||||
| bool InferValueRangePass::InputIsConstOrHasValueRange(const NodePtr &node) { | |||||
| bool input_is_const_or_has_value_range = true; | |||||
| auto cur_op_desc = node->GetOpDesc(); | |||||
| auto in_data_anchors = node->GetAllInDataAnchors(); | |||||
| for (auto i = 0; i < in_data_anchors.size(); ++i) { | |||||
| auto peer_out_anchor = in_data_anchors.at(i)->GetPeerOutAnchor(); | |||||
| if (peer_out_anchor == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto peer_node = peer_out_anchor->GetOwnerNode(); | |||||
| if (peer_node == nullptr || peer_node->GetOpDesc() == nullptr) { | |||||
| continue; | |||||
| } | |||||
| if ((peer_node->GetType() == CONSTANT) || (peer_node->GetType() == CONSTANTOP)) { | |||||
| continue; | |||||
| } | |||||
| const auto &input_desc = cur_op_desc->GetInputDesc(i); | |||||
| std::vector<std::pair<int64_t, int64_t>> value_range; | |||||
| (void)input_desc.GetValueRange(value_range); | |||||
| if (value_range.empty()) { | |||||
| int peer_out_idx = peer_out_anchor->GetIdx(); | |||||
| auto peer_out_desc = peer_node->GetOpDesc()->MutableOutputDesc(static_cast<uint32_t>(peer_out_idx)); | |||||
| (void)peer_out_desc->GetValueRange(value_range); | |||||
| if (value_range.empty()) { | |||||
| input_is_const_or_has_value_range = false; | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| return input_is_const_or_has_value_range; | |||||
| } | |||||
| template <typename T> | |||||
| graphStatus InferValueRangePass::ConstructData(const GeTensorDesc &tensor_desc, bool use_floor_value, GeTensorPtr &output_ptr) { | |||||
| std::vector<std::pair<int64_t, int64_t>> value_range; | |||||
| (void)tensor_desc.GetValueRange(value_range); | |||||
| if (value_range.size() != tensor_desc.GetShape().GetShapeSize()) { | |||||
| REPORT_INNER_ERROR("E19999", "Value range of input %s is invalid.", tensor_desc.GetName().c_str()); | |||||
| GELOGE(GRAPH_PARAM_INVALID, "Value range of input %s is invalid.", tensor_desc.GetName().c_str()); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| auto value_range_data_num = value_range.size(); | |||||
| unique_ptr<T[]> buf(new (std::nothrow) T[value_range_data_num]()); | |||||
| if (buf == nullptr) { | |||||
| REPORT_INNER_ERROR("E19999", "New buf failed"); | |||||
| GELOGE(MEMALLOC_FAILED, "new buf failed"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| for (auto j = 0; j < value_range_data_num; ++j) { | |||||
| auto value_range_j = use_floor_value ? value_range[j].first : value_range[j].second; | |||||
| buf[j] = static_cast<T>(value_range_j); | |||||
| } | |||||
| if (output_ptr->SetData(reinterpret_cast<uint8_t *>(buf.get()), value_range_data_num * sizeof(T)) != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "set data failed"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus InferValueRangePass::ConstructDataByType(const GeTensorDesc &tensor_desc, bool use_floor_value, GeTensorPtr &output_ptr) { | |||||
| graphStatus ret = GRAPH_SUCCESS; | |||||
| auto data_type = tensor_desc.GetDataType(); | |||||
| output_ptr->MutableTensorDesc().SetDataType(data_type); | |||||
| switch (data_type) { | |||||
| case DT_FLOAT: | |||||
| ret = ConstructData<float>(tensor_desc, use_floor_value, output_ptr); | |||||
| break; | |||||
| case DT_DOUBLE: | |||||
| ret = ConstructData<double>(tensor_desc, use_floor_value, output_ptr); | |||||
| break; | |||||
| case DT_UINT8: | |||||
| ret = ConstructData<uint8_t>(tensor_desc, use_floor_value, output_ptr); | |||||
| break; | |||||
| case DT_INT8: | |||||
| ret = ConstructData<int8_t>(tensor_desc, use_floor_value, output_ptr); | |||||
| break; | |||||
| case DT_UINT16: | |||||
| ret = ConstructData<uint16_t>(tensor_desc, use_floor_value, output_ptr); | |||||
| break; | |||||
| case DT_INT16: | |||||
| ret = ConstructData<int16_t>(tensor_desc, use_floor_value, output_ptr); | |||||
| break; | |||||
| case DT_INT32: | |||||
| ret = ConstructData<int32_t>(tensor_desc, use_floor_value, output_ptr); | |||||
| break; | |||||
| case DT_INT64: | |||||
| ret = ConstructData<int64_t>(tensor_desc, use_floor_value, output_ptr); | |||||
| break; | |||||
| default: | |||||
| GELOGW("Data type:%s is not supported.", TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
| ret = GRAPH_FAILED; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| vector<ConstGeTensorPtr> InferValueRangePass::ConstructInputTensors(const NodePtr &node, bool use_floor_value) { | |||||
| vector<ConstGeTensorPtr> input_tensors; | |||||
| auto cur_op_desc = node->GetOpDesc(); | |||||
| auto in_data_anchors = node->GetAllInDataAnchors(); | |||||
| for (auto i = 0; i < in_data_anchors.size(); ++i) { | |||||
| auto peer_out_anchor = in_data_anchors.at(i)->GetPeerOutAnchor(); | |||||
| if (peer_out_anchor == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto peer_node = peer_out_anchor->GetOwnerNode(); | |||||
| if (peer_node == nullptr) { | |||||
| continue; | |||||
| } | |||||
| // construct input tensor by constant node | |||||
| if ((peer_node->GetType() == CONSTANT) || (peer_node->GetType() == CONSTANTOP)) { | |||||
| vector<GeTensorPtr> const_weight = OpDescUtils::MutableWeights(peer_node); | |||||
| if (const_weight.empty()) { | |||||
| REPORT_INNER_ERROR("E19999", "MutableWeights failed, weight is empty, node: %s(%s)", | |||||
| peer_node->GetName().c_str(), peer_node->GetType().c_str()); | |||||
| GELOGE(INTERNAL_ERROR, "MutableWeights failed, weight is empty, node: %s(%s)", peer_node->GetName().c_str(), | |||||
| peer_node->GetType().c_str()); | |||||
| return vector<ConstGeTensorPtr>(); | |||||
| } | |||||
| // const/constant op has only one weight | |||||
| if (const_weight.at(0) == nullptr) { | |||||
| REPORT_INNER_ERROR("E19999", "MutableWeights failed, weight of constant is null, node: %s(%s)", | |||||
| peer_node->GetName().c_str(), peer_node->GetType().c_str()); | |||||
| GELOGE(INTERNAL_ERROR, "MutableWeights failed, weight of constant is null, node name: %s(%s)", | |||||
| peer_node->GetName().c_str(), peer_node->GetType().c_str()); | |||||
| return vector<ConstGeTensorPtr>(); | |||||
| } | |||||
| input_tensors.push_back(const_weight.at(0)); | |||||
| continue; | |||||
| } | |||||
| // construct input tensor by boundary of value range | |||||
| const auto &input_tensor_desc = cur_op_desc->GetInputDesc(i); | |||||
| GeTensorPtr tmp_tensor_ptr = MakeShared<GeTensor>(input_tensor_desc); | |||||
| if (tmp_tensor_ptr == nullptr) { | |||||
| REPORT_INNER_ERROR("E19999", "Make shared failed"); | |||||
| GELOGE(MEMALLOC_FAILED, "Make shared failed"); | |||||
| return vector<ConstGeTensorPtr>(); | |||||
| } | |||||
| auto ret = ConstructDataByType(input_tensor_desc, use_floor_value, tmp_tensor_ptr); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| REPORT_INNER_ERROR("E19999", "Input %s construct input tensor by boundary of value range failed.", | |||||
| input_tensor_desc.GetName().c_str()); | |||||
| GELOGE(GRAPH_PARAM_INVALID, "Input %s construct input tensor by boundary of value range failed.", | |||||
| input_tensor_desc.GetName().c_str()); | |||||
| return vector<ConstGeTensorPtr>(); | |||||
| } | |||||
| input_tensors.push_back(tmp_tensor_ptr); | |||||
| } | |||||
| return input_tensors; | |||||
| } | |||||
| graphStatus InferValueRangePass::ConstructInputAndInferValueRange(NodePtr &node) { | |||||
| auto inputs = ConstructInputTensors(node, true); | |||||
| if (inputs.empty()) { | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| vector<GeTensorPtr> outputs_lower; | |||||
| auto ret = RunCpuKernelForValueRange(node, inputs, outputs_lower); | |||||
| if (ret != SUCCESS) { | |||||
| REPORT_INNER_ERROR("E19999", "Calculate for node %s(%s) failed", node->GetName().c_str(), node->GetType().c_str()); | |||||
| GELOGE(GRAPH_FAILED, "Calculate for node %s failed in constant folding", node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| inputs = ConstructInputTensors(node, false); | |||||
| if (inputs.empty()) { | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| vector<GeTensorPtr> outputs_higher; | |||||
| ret = RunCpuKernelForValueRange(node, inputs, outputs_higher); | |||||
| if (ret != SUCCESS) { | |||||
| REPORT_INNER_ERROR("E19999", "Calculate for node %s(%s) failed", node->GetName().c_str(), node->GetType().c_str()); | |||||
| GELOGE(GRAPH_FAILED, "Calculate for node %s failed in constant folding", node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| // construct value range from output tensor | |||||
| OpDescPtr node_desc = node->GetOpDesc(); | |||||
| std::vector<std::pair<int64_t, int64_t>> output_tensor_value_range; | |||||
| size_t node_output_desc_size = node_desc->GetOutputsSize(); | |||||
| for (size_t i = 0; i < node_output_desc_size; ++i) { | |||||
| output_tensor_value_range.clear(); | |||||
| auto lower_tensor = outputs_lower[i]; | |||||
| auto lower_tensor_shape_size = lower_tensor->GetTensorDesc().GetShape().GetShapeSize(); | |||||
| auto higher_tensor = outputs_higher[i]; | |||||
| auto higher_tensor_shape_size = higher_tensor->GetTensorDesc().GetShape().GetShapeSize(); | |||||
| auto output_tensor_desc = node_desc->MutableOutputDesc(i); | |||||
| auto output_tensor_shape_size = output_tensor_desc->GetShape().GetShapeSize(); | |||||
| if (output_tensor_shape_size != lower_tensor_shape_size || output_tensor_shape_size != higher_tensor_shape_size) { | |||||
| GELOGE(GRAPH_PARAM_INVALID, "Value range of output %s is invalid.", output_tensor_desc->GetName().c_str()); | |||||
| } | |||||
| auto data_type = output_tensor_desc->GetDataType(); | |||||
| switch (data_type) { | |||||
| GET_DATA_BY_DTYPE(DT_INT8, int8_t) | |||||
| GET_DATA_BY_DTYPE(DT_INT16, int16_t) | |||||
| GET_DATA_BY_DTYPE(DT_INT32, int32_t) | |||||
| GET_DATA_BY_DTYPE(DT_INT64, int64_t) | |||||
| GET_DATA_BY_DTYPE(DT_UINT8, uint8_t) | |||||
| GET_DATA_BY_DTYPE(DT_UINT16, uint16_t) | |||||
| GET_DATA_BY_DTYPE(DT_UINT32, uint32_t) | |||||
| GET_DATA_BY_DTYPE(DT_UINT64, uint64_t) | |||||
| GET_DATA_BY_DTYPE(DT_FLOAT, float) | |||||
| GET_DATA_BY_DTYPE(DT_DOUBLE, double) | |||||
| default: | |||||
| GELOGW("Data type:%s is not supported.", TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| output_tensor_desc->SetValueRange(output_tensor_value_range); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| template <typename T> | |||||
| void InferValueRangePass::ConstructValueRange(const GeTensorPtr &left_tensor, const GeTensorPtr &right_tensor, | |||||
| std::vector<std::pair<int64_t, int64_t>> &value_range) { | |||||
| auto x = reinterpret_cast<const T *>(left_tensor->GetData().GetData()); | |||||
| auto y = reinterpret_cast<const T *>(right_tensor->GetData().GetData()); | |||||
| for (auto j = 0; j < left_tensor->GetTensorDesc().GetShape().GetShapeSize(); ++j) { | |||||
| auto left = static_cast<int64_t>(*(x + j)); | |||||
| auto right = static_cast<int64_t>(*(y + j)); | |||||
| value_range.emplace_back(std::make_pair(left, right)); | |||||
| } | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,44 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GRAPH_PASSES_INFER_VALUE_RANGE_PASS_H_ | |||||
| #define GE_GRAPH_PASSES_INFER_VALUE_RANGE_PASS_H_ | |||||
| #include "graph/passes/infer_base_pass.h" | |||||
| namespace ge { | |||||
| class InferValueRangePass : public InferBasePass { | |||||
| public: | |||||
| graphStatus Infer(NodePtr &node) override; | |||||
| bool TensorDescChanged(const GeTensorDescPtr &src, const GeTensorDescPtr &dst) override; | |||||
| graphStatus UpdateInputDescAttr(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override; | |||||
| bool NeedInfer(const NodePtr &node) override; | |||||
| void AnalyzeFailedInfo(const NodePtr &node) override; | |||||
| private: | |||||
| bool InputIsDynamic(const NodePtr &node); | |||||
| bool InputIsConstOrHasValueRange(const NodePtr &node); | |||||
| template <typename T> | |||||
| graphStatus ConstructData(const GeTensorDesc &tensor_desc, bool use_floor_value, GeTensorPtr &output_ptr); | |||||
| graphStatus ConstructDataByType(const GeTensorDesc &tensor_desc, bool use_floor_value, GeTensorPtr &output_ptr); | |||||
| vector<ConstGeTensorPtr> ConstructInputTensors(const NodePtr &node, bool use_floor_value); | |||||
| template <typename T> | |||||
| void ConstructValueRange(const GeTensorPtr &left_tensor, const GeTensorPtr &right_tensor, | |||||
| std::vector<std::pair<int64_t, int64_t>> &value_range); | |||||
| graphStatus ConstructInputAndInferValueRange(NodePtr &node); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // GE_GRAPH_PASSES_INFER_VALUE_RANGE_PASS_H_ | |||||
| @@ -19,15 +19,84 @@ | |||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "analyzer/analyzer.h" | #include "analyzer/analyzer.h" | ||||
| #include "framework/common/util.h" | #include "framework/common/util.h" | ||||
| #include "graph/shape_refiner.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "graph/utils/node_utils.h" | |||||
| #include "graph/common/omg_util.h" | #include "graph/common/omg_util.h" | ||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "utils/tensor_utils.h" | |||||
| #include "utils/type_utils.h" | |||||
| #include "graph/debug/ge_util.h" | |||||
| #include "graph/operator_factory_impl.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "graph/utils/node_utils.h" | |||||
| #include "graph/utils/tensor_utils.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| namespace ge { | namespace ge { | ||||
| namespace { | |||||
| const char *const kPreOpInputShapeRange = "_pre_op_in_range"; | |||||
| thread_local std::unordered_map<NodePtr, InferenceContextPtr> context_map; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void InferShapePass::ClearContextMap() { context_map.clear(); } | |||||
| InferenceContextPtr CreateInferenceContextPtr(const std::unordered_map<NodePtr, InferenceContextPtr> &context_map, | |||||
| const NodePtr &node) { | |||||
| if (node == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "node is null"); | |||||
| return nullptr; | |||||
| } | |||||
| InferenceContextPtr inference_context = std::shared_ptr<InferenceContext>(InferenceContext::Create()); | |||||
| if (inference_context == nullptr) { | |||||
| REPORT_CALL_ERROR("E19999", "Failed to alloc InferenceContext, node:%s", node->GetName().c_str()); | |||||
| GELOGE(GRAPH_FAILED, "[Alloc][InferenceContext] failed."); | |||||
| return nullptr; | |||||
| } | |||||
| auto all_in_data_anchors = node->GetAllInDataAnchors(); | |||||
| std::vector<std::vector<ShapeAndType>> input_shapes_and_types(all_in_data_anchors.size()); | |||||
| std::vector<std::string> marks; | |||||
| bool has_input_shapes_and_types = false; | |||||
| for (const auto &in_anchor : all_in_data_anchors) { | |||||
| const auto &out_anchor = in_anchor->GetPeerOutAnchor(); | |||||
| if (out_anchor == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto input_node = out_anchor->GetOwnerNode(); | |||||
| if (input_node == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto iter = context_map.find(input_node); | |||||
| if (iter != context_map.end()) { | |||||
| const auto &src_context = iter->second; | |||||
| GE_IF_BOOL_EXEC(src_context == nullptr, REPORT_INNER_ERROR("E19999", "src_context is null."); | |||||
| GELOGE(GRAPH_FAILED, "[Check][Param] src_context is null."); return nullptr); | |||||
| GELOGD("node:%s get %ld marks from node:%s", node->GetName().c_str(), src_context->GetMarks().size(), | |||||
| input_node->GetName().c_str()); | |||||
| for (auto mark : src_context->GetMarks()) { | |||||
| marks.push_back(mark); | |||||
| } | |||||
| auto output_idx = out_anchor->GetIdx(); | |||||
| auto input_idx = in_anchor->GetIdx(); | |||||
| auto output_shape_and_type = src_context->GetOutputHandleShapesAndTypes(); | |||||
| if (output_idx < static_cast<int>(output_shape_and_type.size())) { | |||||
| GELOGI("Add shape and type from %s:%d to %s:%d", input_node->GetName().c_str(), output_idx, | |||||
| node->GetName().c_str(), input_idx); | |||||
| input_shapes_and_types[input_idx] = output_shape_and_type[output_idx]; | |||||
| has_input_shapes_and_types = true; | |||||
| } else { | |||||
| GELOGI("[%s] Output out of range. index = %d, size = %zu", node->GetName().c_str(), output_idx, | |||||
| output_shape_and_type.size()); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (has_input_shapes_and_types) { | |||||
| inference_context->SetInputHandleShapesAndTypes(std::move(input_shapes_and_types)); | |||||
| } | |||||
| inference_context->SetMarks(marks); | |||||
| return inference_context; | |||||
| } | |||||
| void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) { | void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) { | ||||
| desc_str += "["; | desc_str += "["; | ||||
| @@ -61,7 +130,8 @@ std::string GetInTensorInfoWithString(const ge::NodePtr &node) { | |||||
| if (in_idx > 0) { | if (in_idx > 0) { | ||||
| ss << " "; | ss << " "; | ||||
| } | } | ||||
| ss << "input_" << in_idx << " " << "tensor: ["; | |||||
| ss << "input_" << in_idx << " " | |||||
| << "tensor: ["; | |||||
| ss << "(shape:[" << input_desc->MutableShape().ToString() << "]),"; | ss << "(shape:[" << input_desc->MutableShape().ToString() << "]),"; | ||||
| ss << "(format:" << TypeUtils::FormatToSerialString(input_desc->GetFormat()) << "),"; | ss << "(format:" << TypeUtils::FormatToSerialString(input_desc->GetFormat()) << "),"; | ||||
| ss << "(dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) << "),"; | ss << "(dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) << "),"; | ||||
| @@ -76,28 +146,180 @@ std::string GetInTensorInfoWithString(const ge::NodePtr &node) { | |||||
| return ss.str(); | return ss.str(); | ||||
| } | } | ||||
| Status InferShapePass::Run(NodePtr &node) { | |||||
| // kOptimizeAfterSubGraph exist means after subgraph | |||||
| auto ret = ShapeRefiner::InferShapeAndType(node, !OptionExists(kOptimizeAfterSubGraph)); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| // select INFERSHAPE failed info | |||||
| auto graph = node->GetOwnerComputeGraph(); | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| auto root_graph = ge::GraphUtils::FindRootGraph(graph); | |||||
| GE_CHECK_NOTNULL(root_graph); | |||||
| analyzer::DataInfo analyze_info{root_graph->GetSessionID(), root_graph->GetGraphID(), | |||||
| analyzer::INFER_SHAPE, node, "InferShapeFailed!"}; | |||||
| (void)Analyzer::GetInstance()->DoAnalyze(analyze_info); | |||||
| (void)Analyzer::GetInstance()->SaveAnalyzerDataToFile(root_graph->GetSessionID(), | |||||
| root_graph->GetGraphID()); | |||||
| REPORT_CALL_ERROR("E19999", "Call InferShapeAndType for node:%s(%s) failed, input_tensor:%s", | |||||
| node->GetName().c_str(), node->GetType().c_str(), GetInTensorInfoWithString(node).c_str()); | |||||
| GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "[Call][InferShapeAndType] for node:%s(%s) failed, input_tensor:%s", | |||||
| node->GetName().c_str(), node->GetType().c_str(), GetInTensorInfoWithString(node).c_str()); | |||||
| return GE_GRAPH_INFERSHAPE_FAILED; | |||||
| void InferShapePass::AnalyzeFailedInfo(const NodePtr &node) { | |||||
| auto graph = node->GetOwnerComputeGraph(); | |||||
| if (graph == nullptr) { | |||||
| GELOGW("Owner compute graph of node %s is nullptr", node->GetName().c_str()); | |||||
| } | |||||
| auto root_graph = ge::GraphUtils::FindRootGraph(graph); | |||||
| if (root_graph == nullptr) { | |||||
| GELOGW("Root compute graph of node %s is nullptr", node->GetName().c_str()); | |||||
| } | |||||
| analyzer::DataInfo analyze_info{root_graph->GetSessionID(), root_graph->GetGraphID(), analyzer::INFER_SHAPE, node, | |||||
| "InferShapeFailed!"}; | |||||
| (void)Analyzer::GetInstance()->DoAnalyze(analyze_info); | |||||
| (void)Analyzer::GetInstance()->SaveAnalyzerDataToFile(root_graph->GetSessionID(), root_graph->GetGraphID()); | |||||
| REPORT_CALL_ERROR("E19999", "Call InferShapeAndType for node:%s(%s) failed, input_tensor:%s", node->GetName().c_str(), | |||||
| node->GetType().c_str(), GetInTensorInfoWithString(node).c_str()); | |||||
| GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "[Call][InferShapeAndType] for node:%s(%s) failed, input_tensor:%s", | |||||
| node->GetName().c_str(), node->GetType().c_str(), GetInTensorInfoWithString(node).c_str()); | |||||
| } | |||||
| bool InferShapePass::TensorDescChanged(const GeTensorDescPtr &src, const GeTensorDescPtr &dst) { | |||||
| bool changed = false; | |||||
| const auto &dst_dims = dst->GetShape().GetDims(); | |||||
| const auto &src_dims = src->GetShape().GetDims(); | |||||
| if (dst_dims != src_dims) { | |||||
| changed = true; | |||||
| } | |||||
| return changed; | |||||
| } | |||||
| graphStatus InferShapePass::UpdateInputDescAttr(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) { | |||||
| dst->SetOriginShape(src->GetOriginShape()); | |||||
| dst->SetShape(src->MutableShape()); | |||||
| dst->SetDataType(src->GetDataType()); | |||||
| dst->SetOriginDataType(src->GetOriginDataType()); | |||||
| if (src->MutableShape().GetDims() != UNKNOWN_RANK) { | |||||
| std::vector<std::pair<int64_t, int64_t>> shape_range; | |||||
| (void)src->GetShapeRange(shape_range); | |||||
| dst->SetShapeRange(shape_range); | |||||
| } | |||||
| std::vector<int64_t> pre_op_in_range; | |||||
| if (ge::AttrUtils::GetListInt(*src, kPreOpInputShapeRange, pre_op_in_range)) { | |||||
| (void)ge::AttrUtils::SetListInt(*dst, kPreOpInputShapeRange, pre_op_in_range); | |||||
| } | |||||
| ge::TensorUtils::SetRealDimCnt(*dst, static_cast<uint32_t>(src->MutableShape().GetDims().size())); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus InferShapePass::Infer(NodePtr &node) { | |||||
| bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag(); | |||||
| auto opdesc = node->GetOpDesc(); | |||||
| if (node->Verify() != GRAPH_SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", "Verifying %s failed.", node->GetName().c_str()); | |||||
| GELOGE(GRAPH_FAILED, "[Call][Verify] Verifying %s failed.", node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| PrintInOutTensorShape(node, "before_infershape"); | |||||
| Operator op = OpDescUtils::CreateOperatorFromNode(node); | |||||
| if (!is_unknown_graph) { | |||||
| auto inference_context = CreateInferenceContextPtr(context_map, node); | |||||
| GE_CHECK_NOTNULL(inference_context); | |||||
| GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size()); | |||||
| op.SetInferenceContext(inference_context); | |||||
| } | |||||
| graphStatus status = CallInferShapeFunc(node, op); | |||||
| if (status != GRAPH_PARAM_INVALID && status != GRAPH_SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", "%s call infer function failed.", node->GetName().c_str()); | |||||
| GELOGE(GRAPH_FAILED, "[Call][InferFunction] failed, node:%s.", node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | } | ||||
| if (!is_unknown_graph) { | |||||
| auto ctx_after_infer = op.GetInferenceContext(); | |||||
| if (ctx_after_infer != nullptr) { | |||||
| GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size()); | |||||
| if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !ctx_after_infer->GetMarks().empty()) { | |||||
| GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(), | |||||
| ctx_after_infer->GetMarks().size()); | |||||
| (void)context_map.emplace(node, ctx_after_infer); | |||||
| } | |||||
| } | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus InferShapePass::CallInferShapeFunc(NodePtr &node, Operator &op) { | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| const auto &op_type = op_desc->GetType(); | |||||
| auto ret = op_desc->CallInferFunc(op); | |||||
| if (ret == GRAPH_PARAM_INVALID) { | |||||
| // Op ir no infer func, try to get infer func from operator factory | |||||
| auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType()); | |||||
| if (node_op.IsEmpty()) { | |||||
| GELOGW("get op from OperatorFactory fail. opType: %s", op_type.c_str()); | |||||
| return ret; | |||||
| } | |||||
| GELOGD("get op from OperatorFactory success. opType: %s", op_type.c_str()); | |||||
| auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op); | |||||
| node_op.BreakConnect(); | |||||
| if (temp_op_desc == nullptr) { | |||||
| REPORT_CALL_ERROR("E19999", "GetOpDescFromOperator failed, return nullptr."); | |||||
| GELOGE(GRAPH_FAILED, "[Get][OpDesc] temp op desc is null"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) { | |||||
| GELOGW("InferShapeAndType UpdateInputName failed"); | |||||
| for (const auto &out_desc : op_desc->GetAllOutputsDescPtr()) { | |||||
| if (out_desc != nullptr && out_desc->GetShape().GetDims().empty()) { | |||||
| break; | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| } | |||||
| if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) { | |||||
| GELOGW("InferShapeAndType UpdateOutputName failed"); | |||||
| } | |||||
| op_desc->AddInferFunc(temp_op_desc->GetInferFunc()); | |||||
| ret = op_desc->CallInferFunc(op); | |||||
| GELOGI("op CallInferFunc second. ret: %u", ret); | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| graphStatus InferShapePass::UpdatePeerInputs(NodePtr &node) { | |||||
| bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag(); | |||||
| if (is_unknown_graph) { | |||||
| PrintInOutTensorShape(node, "after_infershape when running"); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| UpdateInputOutputOriginAttr(node); | |||||
| if (NodeUtils::UpdatePeerNodeInputDesc(node) != SUCCESS) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| PrintInOutTensorShape(node, "after_infershape"); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| void InferShapePass::UpdateInputOutputOriginAttr(NodePtr &node) { | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| for (const auto &out_anchor : node->GetAllOutDataAnchors()) { | |||||
| auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx()); | |||||
| if (output_tensor == nullptr) { | |||||
| continue; | |||||
| } | |||||
| if (output_tensor->MutableShape().GetDims().empty()) { | |||||
| output_tensor->SetOriginShape(output_tensor->GetShape()); | |||||
| } | |||||
| ge::TensorUtils::SetRealDimCnt(*output_tensor, | |||||
| static_cast<uint32_t>(output_tensor->GetOriginShape().GetDims().size())); | |||||
| output_tensor->SetOriginDataType(output_tensor->GetDataType()); | |||||
| // set output origin shape range | |||||
| std::vector<std::pair<int64_t, int64_t>> range; | |||||
| (void)output_tensor->GetShapeRange(range); | |||||
| output_tensor->SetOriginShapeRange(range); | |||||
| GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s", node->GetName().c_str(), | |||||
| output_tensor->GetOriginShape().GetShapeSize(), | |||||
| TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str()); | |||||
| } | |||||
| for (const auto &in_anchor : node->GetAllInDataAnchors()) { | |||||
| auto input_tensor = op_desc->MutableInputDesc(in_anchor->GetIdx()); | |||||
| if (input_tensor == nullptr) { | |||||
| continue; | |||||
| } | |||||
| // set input origin shape range | |||||
| std::vector<std::pair<int64_t, int64_t>> range; | |||||
| (void)input_tensor->GetShapeRange(range); | |||||
| input_tensor->SetOriginShapeRange(range); | |||||
| } | |||||
| } | |||||
| Status InferShapePass::DoRepassForLoopNode(NodePtr &node) { | |||||
| GE_CHK_STATUS_RET_NOLOG(RePassLoopNode(node)); | GE_CHK_STATUS_RET_NOLOG(RePassLoopNode(node)); | ||||
| bool need_repass = false; | bool need_repass = false; | ||||
| auto has_attr = AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, need_repass); | auto has_attr = AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, need_repass); | ||||
| @@ -150,13 +372,13 @@ Status InferShapePass::RePassLoopNode(const NodePtr &node) { | |||||
| GE_CHK_STATUS_RET(GetOriginalType(node, node_type), | GE_CHK_STATUS_RET(GetOriginalType(node, node_type), | ||||
| "[Get][OriginalType] of node:%s failed.", node->GetName().c_str()); | "[Get][OriginalType] of node:%s failed.", node->GetName().c_str()); | ||||
| if (kNextIterationOpTypes.count(node_type) > 0) { | if (kNextIterationOpTypes.count(node_type) > 0) { | ||||
| return RePassNode(kMergeOpTypes); // Re-Pass Merge | |||||
| return RePassNode(kMergeOpTypes); // Re-Pass Merge | |||||
| } | } | ||||
| if (kMergeOpTypes.count(node_type) > 0) { | if (kMergeOpTypes.count(node_type) > 0) { | ||||
| if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { | if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { | ||||
| node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); | node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); | ||||
| return RePassNode(kSwitchOpTypes); // Re-Pass Switch | |||||
| return RePassNode(kSwitchOpTypes); // Re-Pass Switch | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -164,12 +386,110 @@ Status InferShapePass::RePassLoopNode(const NodePtr &node) { | |||||
| if (kSwitchOpTypes.count(node_type) > 0) { | if (kSwitchOpTypes.count(node_type) > 0) { | ||||
| if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { | if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { | ||||
| node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); | node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); | ||||
| return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeResume, "need resume"); // Resume Exit | |||||
| return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeResume, "need resume"); // Resume Exit | |||||
| } else { | } else { | ||||
| return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeSuspend, "need suspend"); // Suspend Exit | |||||
| return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeSuspend, "need suspend"); // Suspend Exit | |||||
| } | } | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY | |||||
| graphStatus InferShapePass::InferShapeAndType(NodePtr &node) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
| InferShapePass pass; | |||||
| std::set<NodePtr> unused_changed_nodes; | |||||
| return pass.InferAndUpdate(node, true, unused_changed_nodes); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY | |||||
| graphStatus InferShapePass::InferShapeAndType(NodePtr &node, bool before_subgraph) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
| InferShapePass pass; | |||||
| std::set<NodePtr> unused_changed_nodes; | |||||
| return pass.InferAndUpdate(node, before_subgraph, unused_changed_nodes); | |||||
| } | |||||
| graphStatus InferShapeForRunning::Infer(NodePtr &node) { | |||||
| auto opdesc = node->GetOpDesc(); | |||||
| vector<ge::DataType> temp_dtype; | |||||
| for (auto &tensor_desc : opdesc->GetAllOutputsDescPtr()) { | |||||
| temp_dtype.emplace_back(tensor_desc->GetDataType()); | |||||
| } | |||||
| PrintInOutTensorShape(node, "before_infershape when running"); | |||||
| Operator op = OpDescUtils::CreateOperatorFromNode(node); | |||||
| graphStatus status = CallInferShapeFuncForRunning(node, op); | |||||
| if (status == GRAPH_PARAM_INVALID || status == GRAPH_SUCCESS) { | |||||
| // ensure the dtype is not changed after infershape in running | |||||
| auto after_opdesc = node->GetOpDesc(); | |||||
| GE_IF_BOOL_EXEC(after_opdesc == nullptr, REPORT_INNER_ERROR("E19999", "param node has no opdesc, check invalid."); | |||||
| GELOGE(GRAPH_FAILED, "[Get][OpDesc] after_opdesc is null."); return GRAPH_FAILED); | |||||
| auto all_output_tensor = after_opdesc->GetAllOutputsDescPtr(); | |||||
| for (size_t i = 0; i < all_output_tensor.size(); ++i) { | |||||
| if (all_output_tensor.at(i)->GetDataType() != temp_dtype[i]) { | |||||
| GELOGD("Op %s output %zu need reset dtype,original dtype is %s, new dtype is %s", node->GetName().c_str(), i, | |||||
| TypeUtils::DataTypeToSerialString(all_output_tensor.at(i)->GetDataType()).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(temp_dtype[i]).c_str()); | |||||
| all_output_tensor.at(i)->SetDataType(temp_dtype[i]); | |||||
| } | |||||
| } | |||||
| PrintInOutTensorShape(node, "after_infershape when running"); | |||||
| return GRAPH_SUCCESS; | |||||
| } else { | |||||
| REPORT_CALL_ERROR("E19999", "%s call infer function failed.", node->GetName().c_str()); | |||||
| GELOGE(GRAPH_FAILED, "[Call][InferFunction] failed, node:%s.", node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| } | |||||
| graphStatus InferShapeForRunning::CallInferShapeFuncForRunning(NodePtr &node, Operator &op) { | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| const auto &op_type = op_desc->GetType(); | |||||
| // Create InferenceContext to avoid null pointer access. | |||||
| const static std::set<std::string> force_context_op_types{"Enter", "Switch", "RefSwitch"}; | |||||
| if (force_context_op_types.count(op_type) > 0) { | |||||
| GELOGD("Set InferenceContext for node [%s]", op_desc->GetName().c_str()); | |||||
| op.SetInferenceContext(std::shared_ptr<InferenceContext>(InferenceContext::Create())); | |||||
| } | |||||
| // Get infer func and execute | |||||
| auto ret = op_desc->CallInferFunc(op); | |||||
| if (ret == GRAPH_PARAM_INVALID) { | |||||
| GELOGD("NodeUtils::GetNodeType return value is: [%s]", NodeUtils::GetNodeType(*node).c_str()); | |||||
| auto origin_type = NodeUtils::GetNodeType(*node); | |||||
| auto infer_func = ge::OperatorFactoryImpl::GetInferShapeFunc(origin_type); | |||||
| if (infer_func == nullptr) { | |||||
| REPORT_INNER_ERROR("E19999", "Failed to Get InferFunc. type is %s", origin_type.c_str()); | |||||
| GELOGE(GRAPH_FAILED, "[Get][InferFunc] failed. type is %s", origin_type.c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| op_desc->AddInferFunc(infer_func); | |||||
| ret = op_desc->CallInferFunc(op); | |||||
| GELOGI("op CallInferFunc second. ret: %u", ret); | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| bool InferShapeForRunning::TensorDescChanged(const GeTensorDescPtr &src, const GeTensorDescPtr &dst) { | |||||
| bool changed = false; | |||||
| const auto &dst_dims = dst->GetShape().GetDims(); | |||||
| const auto &src_dims = src->GetShape().GetDims(); | |||||
| if (dst_dims != src_dims) { | |||||
| changed = true; | |||||
| } | |||||
| return changed; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY | |||||
| graphStatus InferShapeForRunning::InferShapeAndTypeForRunning(NodePtr &node, bool before_subgraph) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
| InferShapeForRunning pass; | |||||
| std::set<NodePtr> unused_changed_nodes; | |||||
| return pass.InferAndUpdate(node, before_subgraph, unused_changed_nodes); | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -17,22 +17,38 @@ | |||||
| #ifndef GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ | #ifndef GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ | ||||
| #define GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ | #define GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ | ||||
| #include "graph/passes/base_pass.h" | |||||
| #include "graph/passes/infer_base_pass.h" | |||||
| namespace ge { | namespace ge { | ||||
| class InferShapePass : public BaseNodePass { | |||||
| class InferShapePass : public InferBasePass { | |||||
| public: | public: | ||||
| /// | |||||
| /// Entry of the InferShapePass optimizer | |||||
| /// @param [in] graph: Input ComputeGraph | |||||
| /// @return SUCCESS: Execution succeed | |||||
| /// @return OTHERS: Execution failed | |||||
| /// @author | |||||
| /// | |||||
| Status Run(ge::NodePtr &node) override; | |||||
| static void ClearContextMap(); | |||||
| graphStatus Infer(NodePtr &node) override; | |||||
| bool TensorDescChanged(const GeTensorDescPtr &src, const GeTensorDescPtr &dst) override; | |||||
| graphStatus UpdateInputDescAttr(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override; | |||||
| void AnalyzeFailedInfo(const NodePtr &node) override; | |||||
| static graphStatus InferShapeAndType(NodePtr &node); // temp: visible static func | |||||
| static graphStatus InferShapeAndType(NodePtr &node, bool before_subgraph); // temp: visible static func | |||||
| private: | |||||
| graphStatus CallInferShapeFunc(NodePtr &node, Operator &op); | |||||
| graphStatus UpdatePeerInputs(NodePtr &node) override; // only for infershape, will be deleted | |||||
| void UpdateInputOutputOriginAttr(NodePtr &node); // only for infershape, will be deleted | |||||
| Status DoRepassForLoopNode(NodePtr &node) override; // only for infershape, will be deleted | |||||
| Status RePassLoopNode(const NodePtr &node); // only for infershape, will be deleted | |||||
| }; | |||||
| class InferShapeForRunning : public InferBasePass { | |||||
| public: | |||||
| graphStatus Infer(NodePtr &node) override; | |||||
| bool TensorDescChanged(const GeTensorDescPtr &src, const GeTensorDescPtr &dst) override; | |||||
| static graphStatus InferShapeAndTypeForRunning(NodePtr &node, bool before_subgraph); // temp: visible static func | |||||
| private: | private: | ||||
| Status RePassLoopNode(const NodePtr &node); | |||||
| graphStatus CallInferShapeFuncForRunning(NodePtr &node, Operator &op); | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| #endif // GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ | #endif // GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ | ||||
| @@ -54,6 +54,7 @@ | |||||
| #include "graph/passes/hccl_group_pass.h" | #include "graph/passes/hccl_group_pass.h" | ||||
| #include "graph/passes/identity_pass.h" | #include "graph/passes/identity_pass.h" | ||||
| #include "graph/passes/infershape_pass.h" | #include "graph/passes/infershape_pass.h" | ||||
| #include "graph/passes/infer_value_range_pass.h" | |||||
| #include "graph/passes/merge_pass.h" | #include "graph/passes/merge_pass.h" | ||||
| #include "graph/passes/net_output_pass.h" | #include "graph/passes/net_output_pass.h" | ||||
| #include "graph/passes/no_use_reshape_remove_pass.h" | #include "graph/passes/no_use_reshape_remove_pass.h" | ||||
| @@ -1997,6 +1998,8 @@ Status GraphPrepare::InferShapeForPreprocess() { | |||||
| names_to_passes.emplace_back("MergePass", &merge_pass); | names_to_passes.emplace_back("MergePass", &merge_pass); | ||||
| InferShapePass infer_shape_pass; | InferShapePass infer_shape_pass; | ||||
| names_to_passes.emplace_back("InferShapePass", &infer_shape_pass); | names_to_passes.emplace_back("InferShapePass", &infer_shape_pass); | ||||
| InferValueRangePass infer_value_pass; | |||||
| names_to_passes.emplace_back("InferValuePass", &infer_value_pass); | |||||
| ReplaceWithEmptyConstPass replace_with_empty_const_pass; | ReplaceWithEmptyConstPass replace_with_empty_const_pass; | ||||
| names_to_passes.emplace_back("ReplaceWithEmptyConstPass", &replace_with_empty_const_pass); | names_to_passes.emplace_back("ReplaceWithEmptyConstPass", &replace_with_empty_const_pass); | ||||
| DimensionComputePass dimension_compute_pass; | DimensionComputePass dimension_compute_pass; | ||||
| @@ -219,7 +219,9 @@ set(COMMON_SRC_FILES | |||||
| "${GE_CODE_DIR}/ge/graph/passes/shape_operate_op_remove_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/shape_operate_op_remove_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/assert_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/assert_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/dropout_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/dropout_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/infer_base_pass.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/passes/infershape_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/infershape_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/infer_value_range_pass.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/passes/unused_const_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/unused_const_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/permute_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/permute_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/ctrl_edge_transfer_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/ctrl_edge_transfer_pass.cc" | ||||
| @@ -478,7 +480,7 @@ set(GRAPH_BUILD_COMMON_SRC_FILES | |||||
| ) | ) | ||||
| set(GRAPH_PASS_COMMON_SRC_FILES | set(GRAPH_PASS_COMMON_SRC_FILES | ||||
| "${GE_CODE_DIR}/ge/graph/passes/pass_manager.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/passes/pass_manager.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/passes/base_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/base_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/variable_prepare_op_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/variable_prepare_op_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/variable_ref_delete_op_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/variable_ref_delete_op_pass.cc" | ||||
| @@ -532,7 +534,9 @@ set(GRAPH_PASS_COMMON_SRC_FILES | |||||
| "${GE_CODE_DIR}/ge/graph/passes/transpose_transdata_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/transpose_transdata_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/hccl_memcpy_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/hccl_memcpy_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/no_use_reshape_remove_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/no_use_reshape_remove_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/infer_base_pass.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/passes/infershape_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/infershape_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/infer_value_range_pass.cc" | |||||
| "${GE_CODE_DIR}/ge/ge_local_engine/engine/host_cpu_engine.cc" | "${GE_CODE_DIR}/ge/ge_local_engine/engine/host_cpu_engine.cc" | ||||
| "${GE_CODE_DIR}/ge/analyzer/analyzer.cc" | "${GE_CODE_DIR}/ge/analyzer/analyzer.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/net_output_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/net_output_pass.cc" | ||||
| @@ -702,6 +706,7 @@ set(PASS_TEST_FILES | |||||
| "graph/passes/net_output_pass_unittest.cc" | "graph/passes/net_output_pass_unittest.cc" | ||||
| "graph/passes/no_use_reshape_remove_pass_unittest.cc" | "graph/passes/no_use_reshape_remove_pass_unittest.cc" | ||||
| "graph/passes/infershape_pass_unittest.cc" | "graph/passes/infershape_pass_unittest.cc" | ||||
| "graph/passes/infer_value_range_pass_unittest.cc" | |||||
| "graph/passes/mark_force_unknown_for_cond_pass_unittest.cc" | "graph/passes/mark_force_unknown_for_cond_pass_unittest.cc" | ||||
| "graph/passes/multi_batch_clone_pass_unittest.cc" | "graph/passes/multi_batch_clone_pass_unittest.cc" | ||||
| "graph/passes/subgraph_const_migration_pass_unittest.cc" | "graph/passes/subgraph_const_migration_pass_unittest.cc" | ||||
| @@ -0,0 +1,281 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <gtest/gtest.h> | |||||
| #define protected public | |||||
| #define private public | |||||
| #include "graph/passes/infer_value_range_pass.h" | |||||
| #include "graph/utils/tensor_utils.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "graph_builder_utils.h" | |||||
| #include "inc/external/graph/operator_reg.h" | |||||
| #include "inc/external/graph/operator.h" | |||||
| #include "inc/external/graph/operator_factory.h" | |||||
| #include "inc/graph/operator_factory_impl.h" | |||||
| #include "inc/kernel.h" | |||||
| #include "inc/kernel_factory.h" | |||||
| using namespace std; | |||||
| using namespace testing; | |||||
| namespace ge { | |||||
| class UtestGraphInferValueRangePass : public testing::Test { | |||||
| protected: | |||||
| void SetUp() {} | |||||
| void TearDown() {} | |||||
| }; | |||||
| static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) { | |||||
| auto op_desc = std::make_shared<OpDesc>("AddN", "AddN"); | |||||
| return graph.AddNode(op_desc); | |||||
| } | |||||
| TEST_F(UtestGraphInferValueRangePass, infer_pass_not_register) { | |||||
| auto graph = std::make_shared<ComputeGraph>("test_graph"); | |||||
| GeTensorDesc ge_tensor_desc(GeShape({1, 1, 4, 192}), ge::FORMAT_NCHW, DT_FLOAT16); | |||||
| auto addn_op_desc = std::make_shared<OpDesc>("AddN", "AddN"); | |||||
| addn_op_desc->AddInputDesc(ge_tensor_desc); | |||||
| addn_op_desc->AddOutputDesc(ge_tensor_desc); | |||||
| auto addn_op_node = graph->AddNode(addn_op_desc); | |||||
| InferValueRangePass infer_pass; | |||||
| EXPECT_EQ(infer_pass.Run(addn_op_node), SUCCESS); | |||||
| } | |||||
| auto ShapeValueInfer = [&](Operator &op) { | |||||
| auto op_desc = OpDescUtils::GetOpDescFromOperator(op); | |||||
| auto output_tensor_desc = op_desc->MutableOutputDesc(0); | |||||
| std::vector<std::pair<int64_t, int64_t>> in_shape_range; | |||||
| op_desc->MutableInputDesc(0)->GetShapeRange(in_shape_range); | |||||
| if (!in_shape_range.empty()) { | |||||
| output_tensor_desc->SetValueRange(in_shape_range); | |||||
| } | |||||
| return SUCCESS; | |||||
| }; | |||||
| TEST_F(UtestGraphInferValueRangePass, infer_pass_when_call_1_not_infer) { | |||||
| INFER_VALUE_RANGE_CUSTOM_FUNC_REG(Shape, INPUT_IS_DYNAMIC, ShapeValueInfer); | |||||
| auto graph = std::make_shared<ComputeGraph>("test_graph"); | |||||
| GeTensorDesc ge_tensor_desc(GeShape({1, 1, 4, 192}), ge::FORMAT_NCHW, DT_INT32); | |||||
| std::vector<std::pair<int64_t, int64_t>> shape_range = {make_pair(1, 1), make_pair(1, 1), | |||||
| make_pair(4, 4), make_pair(192, 192)}; | |||||
| ge_tensor_desc.SetShapeRange(shape_range); | |||||
| GeTensorDesc output_tensor_desc(GeShape({4}), ge::FORMAT_NCHW, DT_INT32); | |||||
| auto op_desc = std::make_shared<OpDesc>("Shape", "Shape"); | |||||
| op_desc->AddInputDesc(ge_tensor_desc); | |||||
| op_desc->AddOutputDesc(output_tensor_desc); | |||||
| auto op_node = graph->AddNode(op_desc); | |||||
| InferValueRangePass infer_pass; | |||||
| EXPECT_EQ(infer_pass.Run(op_node), SUCCESS); | |||||
| auto output_0_desc = op_node->GetOpDesc()->GetOutputDesc(0); | |||||
| std::vector<std::pair<int64_t, int64_t>> value_range; | |||||
| output_0_desc.GetValueRange(value_range); | |||||
| EXPECT_EQ(value_range.empty(), true); | |||||
| } | |||||
| TEST_F(UtestGraphInferValueRangePass, infer_pass_when_call_1_infer) { | |||||
| // sqrt -> shape -> Output | |||||
| INFER_VALUE_RANGE_CUSTOM_FUNC_REG(Shape, INPUT_IS_DYNAMIC, ShapeValueInfer); | |||||
| auto graph = std::make_shared<ComputeGraph>("test_graph"); | |||||
| GeTensorDesc sqrt_tensor_desc(GeShape({-1, -1, 4, 192}), ge::FORMAT_NCHW, DT_INT32); | |||||
| std::vector<std::pair<int64_t, int64_t>> shape_range = {make_pair(1, 100), make_pair(1, 240), | |||||
| make_pair(4, 4), make_pair(192, 192)}; | |||||
| sqrt_tensor_desc.SetShapeRange(shape_range); | |||||
| auto sqrt_op_desc = std::make_shared<OpDesc>("Sqrt", "Sqrt"); | |||||
| sqrt_op_desc->AddInputDesc(sqrt_tensor_desc); | |||||
| sqrt_op_desc->AddOutputDesc(sqrt_tensor_desc); | |||||
| auto sqrt_node = graph->AddNode(sqrt_op_desc); | |||||
| GeTensorDesc shape_output_desc(GeShape({4}), ge::FORMAT_NCHW, DT_INT32); | |||||
| auto shape_op_desc = std::make_shared<OpDesc>("Shape", "Shape"); | |||||
| shape_op_desc->AddInputDesc(sqrt_tensor_desc); | |||||
| shape_op_desc->AddOutputDesc(shape_output_desc); | |||||
| auto shape_node = graph->AddNode(shape_op_desc); | |||||
| GeTensorDesc Output_in_tensor_desc(GeShape({4}), ge::FORMAT_NCHW, ge::DT_INT32); | |||||
| auto Output_op_desc = std::make_shared<OpDesc>("Output", "Output"); | |||||
| Output_op_desc->AddInputDesc(Output_in_tensor_desc); | |||||
| auto Output_node = graph->AddNode(Output_op_desc); | |||||
| ge::GraphUtils::AddEdge(sqrt_node->GetOutDataAnchor(0), shape_node->GetInDataAnchor(0)); | |||||
| ge::GraphUtils::AddEdge(shape_node->GetOutDataAnchor(0), Output_node->GetInDataAnchor(0)); | |||||
| EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); | |||||
| InferValueRangePass infer_pass; | |||||
| auto ret = infer_pass.Run(shape_node); | |||||
| EXPECT_EQ(ret, SUCCESS); | |||||
| auto output_0_desc = shape_node->GetOpDesc()->GetOutputDesc(0); | |||||
| std::vector<std::pair<int64_t, int64_t>> value_range; | |||||
| output_0_desc.GetValueRange(value_range); | |||||
| EXPECT_EQ(value_range.size(), 4); | |||||
| std::vector<int64_t> target_value_range = {1, 100, 1, 240, 4, 4, 192, 192}; | |||||
| std::vector<int64_t> output_value_range; | |||||
| for (auto pair : value_range) { | |||||
| output_value_range.push_back(pair.first); | |||||
| output_value_range.push_back(pair.second); | |||||
| } | |||||
| EXPECT_EQ(target_value_range, output_value_range); | |||||
| auto in_0_desc = Output_node->GetOpDesc()->GetInputDesc(0); | |||||
| value_range.clear(); | |||||
| in_0_desc.GetValueRange(value_range); | |||||
| EXPECT_EQ(value_range.size(), 0); | |||||
| /* | |||||
| INFER_VALUE_RANGE_DEFAULT_REG(Output); | |||||
| ret = infer_pass.Run(Output_node); | |||||
| EXPECT_EQ(ret, FAILED); | |||||
| auto in_0_desc_after_infer = Output_node->GetOpDesc()->GetInputDesc(0); | |||||
| value_range.clear(); | |||||
| in_0_desc_after_infer.GetValueRange(value_range); | |||||
| EXPECT_EQ(value_range.size(), 4); | |||||
| output_value_range.clear(); | |||||
| for (auto pair : value_range) { | |||||
| output_value_range.push_back(pair.first); | |||||
| output_value_range.push_back(pair.second); | |||||
| } | |||||
| EXPECT_EQ(target_value_range, output_value_range); | |||||
| auto out_0_desc = Output_node->GetOpDesc()->GetOutputDesc(0); | |||||
| value_range.clear(); | |||||
| out_0_desc.GetValueRange(value_range); | |||||
| EXPECT_EQ(value_range.size(), 0); | |||||
| */ | |||||
| } | |||||
| class AddKernel : public Kernel { | |||||
| public: | |||||
| Status Compute(const ge::OpDescPtr op_desc_ptr, const std::vector<ge::ConstGeTensorPtr> &input, | |||||
| std::vector<ge::GeTensorPtr> &v_output) override { | |||||
| vector<int64_t> data_vec; | |||||
| auto data_num = input[0]->GetTensorDesc().GetShape().GetShapeSize(); | |||||
| auto x1_data = reinterpret_cast<const int64_t *>(input[0]->GetData().data()); | |||||
| auto x2_data = reinterpret_cast<const int64_t *>(input[1]->GetData().data()); | |||||
| for (size_t i = 0; i < data_num; i++) { | |||||
| auto x_index = *(x1_data + i); | |||||
| auto y_index = *(x2_data + i); | |||||
| data_vec.push_back(x_index + y_index); | |||||
| } | |||||
| GeTensorPtr const_tensor = std::make_shared<ge::GeTensor>(input[0]->GetTensorDesc(), (uint8_t *)data_vec.data(), | |||||
| data_num * sizeof(int64_t)); | |||||
| v_output.emplace_back(const_tensor); | |||||
| return SUCCESS; | |||||
| } | |||||
| }; | |||||
| REGISTER_KERNEL(ADD, AddKernel); | |||||
| TEST_F(UtestGraphInferValueRangePass, infer_pass_when_call_2_infer) { | |||||
| // shape --- add --- sqrt | |||||
| // constant / | |||||
| INFER_VALUE_RANGE_DEFAULT_REG(Add); | |||||
| auto graph = std::make_shared<ComputeGraph>("test_graph"); | |||||
| vector<int64_t> dims_vec = {4}; | |||||
| vector<int64_t> data_vec = {1, 1, 1, 1}; | |||||
| GeTensorDesc const_tensor_desc(ge::GeShape(dims_vec), ge::FORMAT_NCHW, ge::DT_INT64); | |||||
| GeTensorPtr const_tensor = | |||||
| std::make_shared<ge::GeTensor>(const_tensor_desc, (uint8_t *)data_vec.data(), data_vec.size() * sizeof(int64_t)); | |||||
| auto const_op_desc = std::make_shared<OpDesc>("Constant", "Constant"); | |||||
| const_op_desc->AddOutputDesc(const_tensor_desc); | |||||
| EXPECT_EQ(OpDescUtils::SetWeights(const_op_desc, const_tensor), GRAPH_SUCCESS); | |||||
| auto const_node = graph->AddNode(const_op_desc); | |||||
| GeTensorDesc shape_tensor_desc(GeShape({4}), ge::FORMAT_NCHW, ge::DT_INT64); | |||||
| std::vector<std::pair<int64_t, int64_t>> value_range = {make_pair(1, 100), make_pair(1, 240), | |||||
| make_pair(4, 4), make_pair(192, 192)}; | |||||
| shape_tensor_desc.SetValueRange(value_range); | |||||
| auto shape_op_desc = std::make_shared<OpDesc>("Shape", "Shape"); | |||||
| shape_op_desc->AddOutputDesc(shape_tensor_desc); | |||||
| auto shape_node = graph->AddNode(shape_op_desc); | |||||
| GeTensorDesc add_tensor_desc(GeShape({4}), ge::FORMAT_NCHW, ge::DT_INT64); | |||||
| auto add_op_desc = std::make_shared<OpDesc>("Add", "Add"); | |||||
| add_op_desc->AddInputDesc(shape_tensor_desc); | |||||
| add_op_desc->AddInputDesc(const_tensor_desc); | |||||
| add_op_desc->AddOutputDesc(add_tensor_desc); | |||||
| auto add_node = graph->AddNode(add_op_desc); | |||||
| ge::GraphUtils::AddEdge(shape_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(0)); | |||||
| ge::GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(1)); | |||||
| InferValueRangePass infer_pass; | |||||
| EXPECT_EQ(infer_pass.Run(add_node), SUCCESS); | |||||
| auto output_0_desc = add_node->GetOpDesc()->GetOutputDesc(0); | |||||
| std::vector<std::pair<int64_t, int64_t>> out_value_range; | |||||
| output_0_desc.GetValueRange(out_value_range); | |||||
| EXPECT_EQ(out_value_range.size(), 4); | |||||
| std::vector<int64_t> target_value_range = {2, 101, 2, 241, 5, 5, 193, 193}; | |||||
| std::vector<int64_t> output_value_range; | |||||
| for (auto pair : out_value_range) { | |||||
| output_value_range.push_back(pair.first); | |||||
| output_value_range.push_back(pair.second); | |||||
| } | |||||
| EXPECT_EQ(target_value_range, output_value_range); | |||||
| } | |||||
| TEST_F(UtestGraphInferValueRangePass, test_value_range_infer_and_set_get) { | |||||
| using std::make_pair; | |||||
| std::function<ge::graphStatus(ge::Operator &)> ShapeValueInfer_ = [](ge::Operator &op) -> ge::graphStatus { | |||||
| auto op_desc = OpDescUtils::GetOpDescFromOperator(op); | |||||
| auto output_tensor_desc = op_desc->MutableOutputDesc(0); | |||||
| std::vector<std::pair<int64_t, int64_t>> in_shape_range; | |||||
| op_desc->MutableInputDesc(0)->GetShapeRange(in_shape_range); | |||||
| if (!in_shape_range.empty()) { | |||||
| output_tensor_desc->SetValueRange(in_shape_range); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| }; | |||||
| INFER_VALUE_RANGE_CUSTOM_FUNC_REG(Shape, INPUT_IS_DYNAMIC, ShapeValueInfer_); | |||||
| string op_type = "Shape"; | |||||
| auto graph = std::make_shared<ComputeGraph>("test_graph"); | |||||
| auto shape_op_desc = std::make_shared<OpDesc>("node_name", op_type); | |||||
| GeTensorDesc tensor_desc(GeShape({-1, -1, 4, 192}), ge::FORMAT_NCHW, DT_INT32); | |||||
| std::vector<std::pair<int64_t, int64_t>> shape_range = {make_pair(1, 100), make_pair(1, 240), | |||||
| make_pair(4, 4), make_pair(192, 192)}; | |||||
| tensor_desc.SetShapeRange(shape_range); | |||||
| shape_op_desc->AddInputDesc(tensor_desc); | |||||
| GeTensorDesc out_tensor_desc(GeShape({4}), ge::FORMAT_NCHW, DT_INT32); | |||||
| shape_op_desc->AddOutputDesc(out_tensor_desc); | |||||
| auto shape_node = graph->AddNode(shape_op_desc); | |||||
| Operator op = OpDescUtils::CreateOperatorFromNode(shape_node); | |||||
| auto ret = shape_node->GetOpDesc()->CallInferValueRangeFunc(op); | |||||
| ASSERT_EQ(ret, GRAPH_SUCCESS); | |||||
| auto output_0_desc = shape_node->GetOpDesc()->GetOutputDesc(0); | |||||
| std::vector<std::pair<int64_t, int64_t>> value_range; | |||||
| output_0_desc.GetValueRange(value_range); | |||||
| EXPECT_EQ(value_range.size(), 4); | |||||
| std::vector<int64_t> target_value_range = {1, 100, 1, 240, 4, 4, 192, 192}; | |||||
| std::vector<int64_t> output_value_range; | |||||
| for (auto pair : value_range) { | |||||
| output_value_range.push_back(pair.first); | |||||
| output_value_range.push_back(pair.second); | |||||
| } | |||||
| EXPECT_EQ(target_value_range, output_value_range); | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -15,6 +15,7 @@ | |||||
| */ | */ | ||||
| #include <gtest/gtest.h> | #include <gtest/gtest.h> | ||||
| #include <operator_factory_impl.h> | |||||
| #define protected public | #define protected public | ||||
| #define private public | #define private public | ||||
| @@ -22,9 +23,12 @@ | |||||
| #include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
| #include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
| #include "graph/operator_factory.h" | |||||
| #include "graph/operator_reg.h" | |||||
| #include "graph_builder_utils.h" | #include "graph_builder_utils.h" | ||||
| #include "inc/external/graph/operator_reg.h" | |||||
| #include "inc/external/graph/operator.h" | |||||
| #include "inc/external/graph/operator_factory.h" | |||||
| #include "inc/graph/operator_factory_impl.h" | |||||
| using namespace std; | using namespace std; | ||||
| using namespace testing; | using namespace testing; | ||||
| @@ -35,6 +39,113 @@ class UtestGraphInfershapePass : public testing::Test { | |||||
| void TearDown() {} | void TearDown() {} | ||||
| }; | }; | ||||
| /* | |||||
| * data1 const1 | |||||
| * \ / | |||||
| * case1 | |||||
| * | | |||||
| * relu10 | |||||
| * | | |||||
| * netoutput | |||||
| */ | |||||
| ut::GraphBuilder ParentGraphBuilder() { | |||||
| ut::GraphBuilder builder = ut::GraphBuilder("g1"); | |||||
| auto data1 = builder.AddNode("data1", "Data", 0, 1); | |||||
| std::vector<int64_t> const_shape = {1}; | |||||
| auto const1 = builder.AddNode("const1", "Const", 0, 1, FORMAT_NCHW, DT_INT32, const_shape); | |||||
| auto case1 = builder.AddNode("case1", CASE, 2, 1); | |||||
| auto relu1 = builder.AddNode("relu10", "Relu", 1, 1); | |||||
| auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0); | |||||
| int32_t weight[1] = {1}; | |||||
| GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_INT32); | |||||
| GeTensorPtr tensor = std::make_shared<GeTensor>(weight_desc, (uint8_t *)weight, sizeof(weight)); | |||||
| OpDescUtils::SetWeights(const1, {tensor}); | |||||
| builder.AddDataEdge(data1, 0, case1, 0); | |||||
| builder.AddDataEdge(const1, 0, case1, 1); | |||||
| builder.AddDataEdge(case1, 0, relu1, 0); | |||||
| builder.AddDataEdge(relu1, 0, netoutput, 0); | |||||
| return builder; | |||||
| } | |||||
| /* | |||||
| * data1 data2 | |||||
| * \ / | |||||
| * switch | |||||
| * / \ | |||||
| * relu1 relu2 | |||||
| * \ / | |||||
| * merge | |||||
| * | | |||||
| * netoutput | |||||
| */ | |||||
| ut::GraphBuilder SwitchSubgraphBuilder(string graph_name, uint32_t num) { | |||||
| ut::GraphBuilder builder = ut::GraphBuilder(graph_name); | |||||
| std::vector<int64_t> shape1 = {2,2}; | |||||
| string data1_name = "data1_" + std::to_string(num); | |||||
| auto data1 = builder.AddNode(data1_name, "Data", 1, 1, FORMAT_NCHW, DT_INT32, shape1); | |||||
| auto data1_desc = data1->GetOpDesc(); | |||||
| EXPECT_NE(data1_desc, nullptr); | |||||
| AttrUtils::SetInt(data1_desc, "_parent_node_index", 0); | |||||
| std::vector<int64_t> shape2 = {3,3}; | |||||
| string data2_name = "data2_" + std::to_string(num); | |||||
| auto data2 = builder.AddNode(data2_name, "Data", 1, 1, FORMAT_NCHW, DT_INT32, shape2); | |||||
| auto data2_desc = data2->GetOpDesc(); | |||||
| EXPECT_NE(data2_desc, nullptr); | |||||
| AttrUtils::SetInt(data2_desc, "_parent_node_index", 1); | |||||
| string switch_name = "switch_" + std::to_string(num); | |||||
| auto switch1 = builder.AddNode(switch_name, "Switch", 2, 2); | |||||
| string relu1_name = "relu1_" + std::to_string(num); | |||||
| auto relu1 = builder.AddNode(relu1_name, "Relu", 1, 1); | |||||
| string relu2_name = "relu2_" + std::to_string(num); | |||||
| auto relu2 = builder.AddNode(relu2_name, "Relu", 1, 1); | |||||
| string merge_name = "merge_" + std::to_string(num); | |||||
| auto merge = builder.AddNode(merge_name, "Merge", 2, 1); | |||||
| std::vector<int64_t> shape7 = {8,8}; | |||||
| string output_name = "output_" + std::to_string(num); | |||||
| auto netoutput = builder.AddNode(output_name, NETOUTPUT, 1, 0, FORMAT_NCHW, DT_INT32, shape7); | |||||
| auto input0_desc = netoutput->GetOpDesc()->MutableInputDesc(0); | |||||
| EXPECT_NE(input0_desc, nullptr); | |||||
| AttrUtils::SetInt(input0_desc, "_parent_node_index", 0); | |||||
| builder.AddDataEdge(data1, 0, switch1, 0); | |||||
| builder.AddDataEdge(data2, 0, switch1, 1); | |||||
| builder.AddDataEdge(switch1, 0, relu1, 0); | |||||
| builder.AddDataEdge(switch1, 1, relu2, 0); | |||||
| builder.AddDataEdge(relu1, 0, merge, 0); | |||||
| builder.AddDataEdge(relu2, 0, merge, 1); | |||||
| builder.AddDataEdge(merge, 0, netoutput, 0); | |||||
| return builder; | |||||
| } | |||||
| void AddCaseSubgraph(ComputeGraphPtr &parent_graph, uint32_t branch_num) { | |||||
| auto case_node = parent_graph->FindNode("case1"); | |||||
| EXPECT_NE(case_node, nullptr); | |||||
| for (uint32_t i = 0; i < branch_num; ++i) { | |||||
| string name = "Branch_Graph_" + std::to_string(i); | |||||
| auto builder_subgraph = SwitchSubgraphBuilder(name, i); | |||||
| auto switch_subgraph = builder_subgraph.GetGraph(); | |||||
| case_node->GetOpDesc()->AddSubgraphName(switch_subgraph->GetName()); | |||||
| case_node->GetOpDesc()->SetSubgraphInstanceName(i, switch_subgraph->GetName()); | |||||
| switch_subgraph->SetParentNode(case_node); | |||||
| switch_subgraph->SetParentGraph(parent_graph); | |||||
| EXPECT_EQ(parent_graph->AddSubgraph(switch_subgraph->GetName(), switch_subgraph), GRAPH_SUCCESS); | |||||
| } | |||||
| } | |||||
| static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) { | static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) { | ||||
| OpDescPtr op_desc = std::make_shared<OpDesc>(name, type); | OpDescPtr op_desc = std::make_shared<OpDesc>(name, type); | ||||
| op_desc->SetStreamId(0); | op_desc->SetStreamId(0); | ||||
| @@ -158,4 +269,218 @@ TEST_F(UtestGraphInfershapePass, stop_node_for_while_loop) { | |||||
| EXPECT_EQ(ge_passes.Run(names_to_passes), SUCCESS); | EXPECT_EQ(ge_passes.Run(names_to_passes), SUCCESS); | ||||
| } | } | ||||
| TEST_F(UtestGraphInfershapePass, infer_with_case_subgraph) { | |||||
| auto builder = ParentGraphBuilder(); | |||||
| auto parent_graph = builder.GetGraph(); | |||||
| AddCaseSubgraph(parent_graph, 2); | |||||
| auto subgraphs = parent_graph->GetAllSubgraphs(); | |||||
| EXPECT_EQ(subgraphs.size(), 2); | |||||
| auto case_node = parent_graph->FindNode("case1"); | |||||
| EXPECT_NE(case_node, nullptr); | |||||
| InferShapePass infershape_pass; | |||||
| EXPECT_EQ(infershape_pass.Run(case_node), SUCCESS); | |||||
| std::vector<int64_t> target_dims_0 = {1, 1, 224, 224}; | |||||
| std::vector<int64_t> target_dims_1 = {1}; | |||||
| { | |||||
| auto data_node = subgraphs[0]->FindNode("data1_0"); | |||||
| auto dims = data_node->GetOpDesc()->GetInputDescPtr(0)->GetShape().GetDims(); | |||||
| EXPECT_EQ(dims, target_dims_0); | |||||
| data_node = subgraphs[0]->FindNode("data2_0"); | |||||
| dims = data_node->GetOpDesc()->GetInputDescPtr(0)->GetShape().GetDims(); | |||||
| EXPECT_EQ(dims, target_dims_1); | |||||
| } | |||||
| infershape_pass.options_[kOptimizeAfterSubGraph] = "yes"; | |||||
| EXPECT_EQ(infershape_pass.Run(case_node), SUCCESS); | |||||
| { | |||||
| auto dims = case_node->GetOpDesc()->GetOutputDescPtr(0)->GetShape().GetDims(); | |||||
| std::vector<int64_t> out_target_dims = {8, 8}; | |||||
| EXPECT_EQ(out_target_dims, dims); | |||||
| } | |||||
| } | |||||
| /* | |||||
| * data1 const1 | |||||
| * \ / | |||||
| * while | |||||
| * / \ | |||||
| * relu1 netoutput | |||||
| */ | |||||
| ut::GraphBuilder ParentWhileGraphBuilder() { | |||||
| ut::GraphBuilder builder = ut::GraphBuilder("g1"); | |||||
| auto data1 = builder.AddNode("data1", "Data", 0, 1); | |||||
| std::vector<int64_t> const_shape = {1}; | |||||
| auto const1 = builder.AddNode("const1", "Const", 0, 1, FORMAT_NCHW, DT_FLOAT, const_shape); | |||||
| auto case1 = builder.AddNode("case1", WHILE, 2, 2); | |||||
| auto relu1 = builder.AddNode("relu1", "Relu", 1, 1); | |||||
| auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0); | |||||
| int32_t weight[1] = {1}; | |||||
| GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_FLOAT); | |||||
| GeTensorPtr tensor = std::make_shared<GeTensor>(weight_desc, (uint8_t *)weight, sizeof(weight)); | |||||
| OpDescUtils::SetWeights(const1, {tensor}); | |||||
| builder.AddDataEdge(data1, 0, case1, 0); | |||||
| builder.AddDataEdge(const1, 0, case1, 1); | |||||
| builder.AddDataEdge(case1, 0, relu1, 0); | |||||
| builder.AddDataEdge(case1, 1, netoutput, 0); | |||||
| return builder; | |||||
| } | |||||
| /* | |||||
| * data1 data2 | |||||
| * \ / | |||||
| * switch | |||||
| * | | | |||||
| * \ / | |||||
| * netoutput | |||||
| */ | |||||
| ut::GraphBuilder WhileSubgraphBuilder(string graph_name, uint32_t num) { | |||||
| ut::GraphBuilder builder = ut::GraphBuilder(graph_name); | |||||
| std::vector<int64_t> shape1 = {2,2}; | |||||
| string data1_name = "data1_" + std::to_string(num); | |||||
| auto data1 = builder.AddNode(data1_name, "Data", 1, 1, FORMAT_NCHW, DT_FLOAT, shape1); | |||||
| auto data1_desc = data1->GetOpDesc(); | |||||
| EXPECT_NE(data1_desc, nullptr); | |||||
| AttrUtils::SetInt(data1_desc, "_parent_node_index", 0); | |||||
| std::vector<int64_t> shape2 = {3,3}; | |||||
| string data2_name = "data2_" + std::to_string(num); | |||||
| auto data2 = builder.AddNode(data2_name, "Data", 1, 1, FORMAT_NCHW, DT_FLOAT, shape2); | |||||
| auto data2_desc = data2->GetOpDesc(); | |||||
| EXPECT_NE(data2_desc, nullptr); | |||||
| AttrUtils::SetInt(data2_desc, "_parent_node_index", 1); | |||||
| string switch_name = "switch_" + std::to_string(num); | |||||
| auto switch1 = builder.AddNode(switch_name, "Switch", 2, 2); | |||||
| std::vector<int64_t> shape7 = {8,8,8,8}; | |||||
| string output_name = "output_" + std::to_string(num); | |||||
| auto netoutput = builder.AddNode(output_name, NETOUTPUT, 2, 0, FORMAT_NCHW, DT_FLOAT, shape7); | |||||
| auto input0_desc = netoutput->GetOpDesc()->MutableInputDesc(0); | |||||
| EXPECT_NE(input0_desc, nullptr); | |||||
| AttrUtils::SetInt(input0_desc, "_parent_node_index", 0); | |||||
| auto input1_desc = netoutput->GetOpDesc()->MutableInputDesc(1); | |||||
| EXPECT_NE(input1_desc, nullptr); | |||||
| AttrUtils::SetInt(input1_desc, "_parent_node_index", 1); | |||||
| builder.AddDataEdge(data1, 0, switch1, 0); | |||||
| builder.AddDataEdge(data2, 0, switch1, 1); | |||||
| builder.AddDataEdge(switch1, 0, netoutput, 0); | |||||
| builder.AddDataEdge(switch1, 1, netoutput, 1); | |||||
| return builder; | |||||
| } | |||||
| void AddWhileSubgraph(ComputeGraphPtr &parent_graph, uint32_t branch_num) { | |||||
| auto case_node = parent_graph->FindNode("case1"); | |||||
| EXPECT_NE(case_node, nullptr); | |||||
| for (uint32_t i = 0; i < branch_num; ++i) { | |||||
| string name = "Branch_Graph_" + std::to_string(i); | |||||
| auto builder_subgraph = WhileSubgraphBuilder(name, i); | |||||
| auto switch_subgraph = builder_subgraph.GetGraph(); | |||||
| case_node->GetOpDesc()->AddSubgraphName(switch_subgraph->GetName()); | |||||
| case_node->GetOpDesc()->SetSubgraphInstanceName(i, switch_subgraph->GetName()); | |||||
| switch_subgraph->SetParentNode(case_node); | |||||
| switch_subgraph->SetParentGraph(parent_graph); | |||||
| EXPECT_EQ(parent_graph->AddSubgraph(switch_subgraph->GetName(), switch_subgraph), GRAPH_SUCCESS); | |||||
| } | |||||
| } | |||||
| TEST_F(UtestGraphInfershapePass, infer_with_while_subgraph) { | |||||
| auto builder = ParentWhileGraphBuilder(); | |||||
| auto parent_graph = builder.GetGraph(); | |||||
| AddWhileSubgraph(parent_graph, 1); | |||||
| auto subgraphs = parent_graph->GetAllSubgraphs(); | |||||
| EXPECT_EQ(subgraphs.size(), 1); | |||||
| auto case_node = parent_graph->FindNode("case1"); | |||||
| EXPECT_NE(case_node, nullptr); | |||||
| InferShapePass infershape_pass; | |||||
| EXPECT_EQ(infershape_pass.Run(case_node), SUCCESS); | |||||
| std::vector<int64_t> target_dims_0 = {1, 1, 224, 224}; | |||||
| std::vector<int64_t> target_dims_1 = {1}; | |||||
| { | |||||
| auto data_node = subgraphs[0]->FindNode("data1_0"); | |||||
| auto dims = data_node->GetOpDesc()->GetInputDescPtr(0)->GetShape().GetDims(); | |||||
| EXPECT_EQ(dims, target_dims_0); | |||||
| data_node = subgraphs[0]->FindNode("data2_0"); | |||||
| dims = data_node->GetOpDesc()->GetInputDescPtr(0)->GetShape().GetDims(); | |||||
| EXPECT_EQ(dims, target_dims_1); | |||||
| } | |||||
| infershape_pass.options_[kOptimizeAfterSubGraph] = "yes"; | |||||
| EXPECT_EQ(infershape_pass.Run(case_node), SUCCESS); | |||||
| { | |||||
| auto dims = case_node->GetOpDesc()->GetOutputDescPtr(0)->GetShape().GetDims(); | |||||
| std::vector<int64_t> out_target_dims = {-1, -1, -1, -1}; | |||||
| EXPECT_EQ(out_target_dims, dims); | |||||
| } | |||||
| } | |||||
| TEST_F(UtestGraphInfershapePass, infer_with_while_subgraph_failed) { | |||||
| auto builder = ParentWhileGraphBuilder(); | |||||
| auto parent_graph = builder.GetGraph(); | |||||
| AddWhileSubgraph(parent_graph, 2); | |||||
| auto subgraphs = parent_graph->GetAllSubgraphs(); | |||||
| EXPECT_EQ(subgraphs.size(), 2); | |||||
| auto case_node = parent_graph->FindNode("case1"); | |||||
| EXPECT_NE(case_node, nullptr); | |||||
| InferShapePass infershape_pass; | |||||
| infershape_pass.options_[kOptimizeAfterSubGraph] = "yes"; | |||||
| EXPECT_EQ(infershape_pass.Run(case_node), GE_GRAPH_INFERSHAPE_FAILED); | |||||
| } | |||||
| auto InferFunc = [&](Operator &op) { | |||||
| return GRAPH_SUCCESS; | |||||
| }; | |||||
| TEST_F(UtestGraphInfershapePass, infer_forrunning_with_while_subgraph) { | |||||
| auto builder = ParentWhileGraphBuilder(); | |||||
| auto parent_graph = builder.GetGraph(); | |||||
| AddWhileSubgraph(parent_graph, 1); | |||||
| auto subgraphs = parent_graph->GetAllSubgraphs(); | |||||
| EXPECT_EQ(subgraphs.size(), 1); | |||||
| OperatorFactoryImpl::RegisterInferShapeFunc("Relu", InferFunc); | |||||
| auto relu_node = parent_graph->FindNode("relu1"); | |||||
| EXPECT_NE(relu_node, nullptr); | |||||
| InferShapeForRunning infershape_for_running; | |||||
| EXPECT_EQ(infershape_for_running.Run(relu_node), SUCCESS); | |||||
| } | |||||
| TEST_F(UtestGraphInfershapePass, infer_static_func) { | |||||
| auto builder = ut::GraphBuilder("test_graph"); | |||||
| auto data_1 = builder.AddNode("data_1", DATA, 0, 1); | |||||
| auto data_2 = builder.AddNode("data_2", DATA, 0, 1); | |||||
| auto add = builder.AddNode("Add", "Add", 2, 1); | |||||
| builder.AddDataEdge(data_1, 0, add, 0); | |||||
| builder.AddDataEdge(data_2, 0, add, 1); | |||||
| auto test_graph = builder.GetGraph(); | |||||
| // OperatorFactoryImpl::CreateOperator("Add", "Flatten"); | |||||
| auto test_node = test_graph->FindNode("Add"); | |||||
| auto ret = InferShapePass::InferShapeAndType(test_node); | |||||
| EXPECT_EQ(ret, GRAPH_SUCCESS); | |||||
| OperatorFactoryImpl::RegisterInferShapeFunc("Add", InferFunc); | |||||
| ret = InferShapePass::InferShapeAndType(test_node); | |||||
| EXPECT_EQ(ret, GRAPH_SUCCESS); | |||||
| ret = InferShapePass::InferShapeAndType(test_node, true); | |||||
| EXPECT_EQ(ret, GRAPH_SUCCESS); | |||||
| ret = InferShapeForRunning::InferShapeAndTypeForRunning(test_node, true); | |||||
| EXPECT_EQ(ret, GRAPH_SUCCESS); | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||