| @@ -298,7 +298,9 @@ set(TRAIN_SRC_LIST | |||
| "graph/passes/hccl_continuous_memcpy_pass.cc" | |||
| "graph/passes/identity_pass.cc" | |||
| "graph/passes/ref_identity_delete_op_pass.cc" | |||
| "graph/passes/infer_base_pass.cc" | |||
| "graph/passes/infershape_pass.cc" | |||
| #"graph/passes/infer_value_range_pass.cc" | |||
| "graph/passes/iterator_op_pass.cc" | |||
| "graph/passes/link_gen_mask_nodes_pass.cc" | |||
| "graph/passes/merge_pass.cc" | |||
| @@ -552,7 +554,9 @@ set(INFER_SRC_LIST | |||
| "graph/passes/shape_operate_op_remove_pass.cc" | |||
| "graph/passes/assert_pass.cc" | |||
| "graph/passes/dropout_pass.cc" | |||
| "graph/passes/infer_base_pass.cc" | |||
| "graph/passes/infershape_pass.cc" | |||
| #"graph/passes/infer_value_range_pass.cc" | |||
| "graph/passes/unused_const_pass.cc" | |||
| "graph/passes/permute_pass.cc" | |||
| "graph/passes/ctrl_edge_transfer_pass.cc" | |||
| @@ -0,0 +1,636 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "infer_base_pass.h" | |||
| #include "common/ge/ge_util.h" | |||
| #include "common/util/error_manager/error_manager.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| #include "framework/common/util.h" | |||
| #include "graph/debug/ge_attr_define.h" | |||
| #include "graph/debug/ge_util.h" | |||
| #include "graph/utils/graph_utils.h" | |||
| #include "graph/utils/node_utils.h" | |||
| #include "graph/utils/tensor_utils.h" | |||
| #include "graph/utils/type_utils.h" | |||
| namespace ge { | |||
| namespace { | |||
| string Serial(const vector<int64_t> &dims) { | |||
| string serial_string; | |||
| serial_string += "["; | |||
| for (int64_t dim : dims) { | |||
| serial_string += std::to_string(dim) + " "; | |||
| } | |||
| serial_string += "]"; | |||
| return serial_string; | |||
| } | |||
| void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) { | |||
| desc_str += "["; | |||
| std::vector<std::pair<int64_t, int64_t>> shape_range; | |||
| (void)desc->GetShapeRange(shape_range); | |||
| for (const auto &pair : shape_range) { | |||
| desc_str += "{"; | |||
| desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second); | |||
| desc_str += "},"; | |||
| } | |||
| desc_str += "]"; | |||
| shape_range.clear(); | |||
| (void)desc->GetOriginShapeRange(shape_range); | |||
| for (const auto &pair : shape_range) { | |||
| desc_str += ",{"; | |||
| desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second); | |||
| desc_str += "},"; | |||
| } | |||
| } | |||
| graphStatus FindSubgraphDataAndNetoutput(std::shared_ptr<ComputeGraph> &sub_graph, NodePtr &netoutput, | |||
| const ConstNodePtr &node, | |||
| std::vector<std::vector<GeTensorDesc>> &ref_data_tensors) { | |||
| auto sub_nodes = sub_graph->GetDirectNode(); | |||
| for (size_t i = sub_nodes.size(); i > 0; --i) { | |||
| auto sub_node = sub_nodes.at(i - 1); | |||
| if (sub_node->GetType() == NETOUTPUT) { | |||
| netoutput = sub_node; | |||
| } | |||
| if (sub_node->GetType() == DATA) { | |||
| if (sub_node->GetOpDesc() == nullptr) { | |||
| return GRAPH_FAILED; | |||
| } | |||
| int ref_i; | |||
| if (!AttrUtils::GetInt(sub_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { | |||
| REPORT_INNER_ERROR("E19999", "subgraph data node[%s] has no parent node!", sub_node->GetName().c_str()); | |||
| GELOGE(GRAPH_FAILED, "[Get][Int] subgraph data node[%s] has no parent node!", sub_node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllInDataAnchorsSize()) { | |||
| REPORT_INNER_ERROR("E19999", "data node[%s]'s ref index[%d] is not in range [0, %u)!", | |||
| sub_node->GetName().c_str(), ref_i, node->GetAllInDataAnchorsSize()); | |||
| GELOGE(GRAPH_FAILED, "[Check][Param] data node[%s]'s ref index[%d] is not in range [0, %u)!", | |||
| sub_node->GetName().c_str(), ref_i, node->GetAllInDataAnchorsSize()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| ref_data_tensors[ref_i].emplace_back(sub_node->GetOpDesc()->GetOutputDesc(0)); | |||
| } | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| } // namespace | |||
| Status InferBasePass::Run(NodePtr &node) { | |||
| GE_CHECK_NOTNULL(node); | |||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | |||
| bool need_infer = NeedInfer(node); | |||
| if (!need_infer) { | |||
| GELOGD("Node %s does not need to infer.", node->GetName().c_str()); | |||
| return SUCCESS; | |||
| } | |||
| std::set<NodePtr> changed_nodes; | |||
| auto ret = InferAndUpdate(node, !OptionExists(kOptimizeAfterSubGraph), changed_nodes); | |||
| if (ret != GRAPH_SUCCESS) { | |||
| (void)AnalyzeFailedInfo(node); | |||
| return GE_GRAPH_INFERSHAPE_FAILED; | |||
| } | |||
| // AddChangedNodesImmediateRepass(changed_nodes); | |||
| auto status = DoRepassForLoopNode(node); | |||
| if (status != SUCCESS) { | |||
| GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "repass failed. node: %s", node->GetName().c_str()); | |||
| return GE_GRAPH_INFERSHAPE_FAILED; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| bool InferBasePass::NeedInfer(const NodePtr &node) { return true; } | |||
| void InferBasePass::AnalyzeFailedInfo(const NodePtr &node) { /* Analyze and select failed info*/ } | |||
| Status InferBasePass::DoRepassForLoopNode(NodePtr &node) { return SUCCESS; } | |||
| void InferBasePass::AddChangedNodesImmediateRepass(std::set<NodePtr> &changed_nodes) { | |||
| for (const auto &node_ele : changed_nodes) { | |||
| AddImmediateRePassNode(node_ele); | |||
| } | |||
| } | |||
| graphStatus InferBasePass::InferAndUpdate(NodePtr &node, bool before_subgraph, std::set<NodePtr> &changed_nodes) { | |||
| bool contain_subgraph = ContainsSubgraph(node); | |||
| if (contain_subgraph && before_subgraph) { | |||
| auto ret = UpdateTensorDescToSubgraphData(node, changed_nodes); | |||
| if (ret != GRAPH_SUCCESS) { | |||
| return ret; | |||
| } | |||
| } | |||
| auto ret = Infer(node); | |||
| if (ret != GRAPH_SUCCESS) { | |||
| return ret; | |||
| } | |||
| if (contain_subgraph && !before_subgraph) { | |||
| return UpdateTensorDescToParentNode(node, changed_nodes); | |||
| } | |||
| return UpdateTensorDescToPeerInputs(node, changed_nodes); | |||
| } | |||
| bool InferBasePass::ContainsSubgraph(const NodePtr &node) { | |||
| auto op_desc = node->GetOpDesc(); | |||
| auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); | |||
| if (sub_graph_names.empty()) { | |||
| return false; | |||
| } | |||
| auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); | |||
| for (const auto &name : sub_graph_names) { | |||
| if (name.empty()) { | |||
| continue; | |||
| } | |||
| auto sub_graph = root_graph->GetSubgraph(name); | |||
| if (sub_graph != nullptr) { | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| graphStatus InferBasePass::UpdateTensorDescToSubgraphData(NodePtr &node, std::set<NodePtr> &changed_nodes) { | |||
| // if infer again, update output of while into subgraph data node | |||
| auto op_desc = node->GetOpDesc(); | |||
| auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); | |||
| if (sub_graph_names.empty()) { | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); | |||
| for (const auto &name : sub_graph_names) { | |||
| if (name.empty()) { | |||
| GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str()); | |||
| continue; | |||
| } | |||
| auto sub_graph = root_graph->GetSubgraph(name); | |||
| if (sub_graph == nullptr) { | |||
| REPORT_INNER_ERROR("E19999", "Can not find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); | |||
| GE_LOGE("[Get][Graph] can not find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| for (const auto &node_sub : sub_graph->GetDirectNode()) { | |||
| if (node_sub->GetType() != DATA) { | |||
| continue; | |||
| } | |||
| int ref_i; | |||
| auto data_opdesc = node_sub->GetOpDesc(); | |||
| if (data_opdesc == nullptr) { | |||
| REPORT_INNER_ERROR("E19999", "Invalid data node on the sub graph %s parent node %s, no OpDesc", name.c_str(), | |||
| node->GetName().c_str()); | |||
| GE_LOGE("[Get][OpDesc] Invalid data node on the sub graph %s parent node %s, no OpDesc", name.c_str(), | |||
| node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| if (!AttrUtils::GetInt(data_opdesc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { | |||
| REPORT_INNER_ERROR("E19999", "Invalid data node on the sub graph %s parent node %s, no ref-index attribute", | |||
| name.c_str(), node->GetName().c_str()); | |||
| GE_LOGE("[Get][Int] Invalid data node on the sub graph %s parent node %s, no ref-index attribute", name.c_str(), | |||
| node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| if (data_opdesc->HasAttr(ATTR_MBATCH_ORIGIN_INPUT_DIMS)) { | |||
| continue; | |||
| } | |||
| auto input_desc = op_desc->MutableInputDesc(ref_i); | |||
| if (input_desc == nullptr) { | |||
| REPORT_INNER_ERROR("E19999", | |||
| "The ref index(%d) on the data %s on the sub graph %s " | |||
| "parent node %s are incompatible, inputs num %u", | |||
| ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), | |||
| node->GetAllInDataAnchorsSize()); | |||
| GE_LOGE( | |||
| "[Call][MutableInputDesc] The ref index(%d) on the data %s on the sub graph %s " | |||
| "parent node %s are incompatible, inputs num %u", | |||
| ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), node->GetAllInDataAnchorsSize()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| GELOGI("Ref index is %d, input_desc dtype is %d, node name is %s", ref_i, input_desc->GetDataType(), | |||
| node->GetName().c_str()); | |||
| // if need infer again, refresh subgraph input with output | |||
| bool is_infer_again = false; | |||
| AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, is_infer_again); | |||
| if (is_infer_again) { | |||
| input_desc = op_desc->MutableOutputDesc(ref_i); | |||
| if (input_desc == nullptr) { | |||
| REPORT_INNER_ERROR("E19999", | |||
| "The ref index(%d) on the data %s on the subgraph %s " | |||
| "parent node %s are incompatible, outputs num %u.", | |||
| ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), | |||
| node->GetAllOutDataAnchorsSize()); | |||
| GELOGE(PARAM_INVALID, | |||
| "[Call][MutableOutputDesc] The ref index(%d) on the data %s on the subgraph %s " | |||
| "parent node %s are incompatible, outputs num %u.", | |||
| ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), | |||
| node->GetAllOutDataAnchorsSize()); | |||
| } | |||
| GELOGD("Update input desc of data %s on the sub graph %s of node %s,output idx: %d from [%s] to [%s]", | |||
| node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), ref_i, | |||
| data_opdesc->GetInputDescPtr(0)->GetShape().ToString().c_str(), | |||
| input_desc->GetShape().ToString().c_str()); | |||
| } | |||
| // auto ret = data_opdesc->UpdateInputDesc(0, *input_desc); | |||
| bool input_changed = false; | |||
| auto data_input_desc = data_opdesc->MutableInputDesc(0); | |||
| auto ret = UpdateTensorDesc(input_desc, data_input_desc,input_changed); | |||
| if (ret != GRAPH_SUCCESS) { | |||
| REPORT_CALL_ERROR("E19999", "Failed to update input desc of data %s on the sub graph %s parent node %s", | |||
| node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str()); | |||
| GE_LOGE("[Update][InputDesc] of data %s on the sub graph %s parent node %s failed", node_sub->GetName().c_str(), | |||
| name.c_str(), node->GetName().c_str()); | |||
| return ret; | |||
| } | |||
| // ret = data_opdesc->UpdateOutputDesc(0, *input_desc); | |||
| bool output_changed = false; | |||
| auto data_output_desc = data_opdesc->MutableOutputDesc(0); | |||
| ret = UpdateTensorDesc(input_desc, data_output_desc,output_changed); | |||
| if (ret != GRAPH_SUCCESS) { | |||
| REPORT_CALL_ERROR("E19999", "Failed to update output desc of data %s on the sub graph %s parent node %s", | |||
| node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str()); | |||
| GE_LOGE("[Update][OutputDesc] of data %s on the sub graph %s parent node %s failed", | |||
| node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str()); | |||
| return ret; | |||
| } | |||
| if (input_changed || output_changed) { | |||
| changed_nodes.insert(node_sub); | |||
| } | |||
| } | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus InferBasePass::UpdateTensorDescToParentNode(NodePtr &node, std::set<NodePtr> &changed_nodes) { | |||
| auto op_desc = node->GetOpDesc(); | |||
| auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); | |||
| if (sub_graph_names.empty()) { | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| std::vector<std::vector<GeTensorDesc>> ref_data_tensors(node->GetAllInDataAnchorsSize()); | |||
| std::vector<std::vector<GeTensorDesc>> ref_out_tensors(node->GetAllOutDataAnchorsSize()); | |||
| auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); | |||
| for (const auto &name : sub_graph_names) { | |||
| if (name.empty()) { | |||
| GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str()); | |||
| continue; | |||
| } | |||
| auto sub_graph = root_graph->GetSubgraph(name); | |||
| if (sub_graph == nullptr) { | |||
| REPORT_INNER_ERROR("E19999", "Can not find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); | |||
| GE_LOGE("[Get][Subgraph] Can not find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| NodePtr netoutput = nullptr; | |||
| auto ret = FindSubgraphDataAndNetoutput(sub_graph, netoutput, node, ref_data_tensors); | |||
| if (ret != GRAPH_SUCCESS) { | |||
| return ret; | |||
| } | |||
| if (netoutput == nullptr) { | |||
| REPORT_INNER_ERROR("E19999", "No NetOutput node on sub graph %s, parent node %s", name.c_str(), | |||
| node->GetName().c_str()); | |||
| GE_LOGE("[Check][Param] No NetOutput node on sub graph %s, parent node %s", name.c_str(), | |||
| node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| auto netoutput_opdesc = netoutput->GetOpDesc(); | |||
| if (netoutput_opdesc == nullptr) { | |||
| REPORT_INNER_ERROR("E19999", "Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it", | |||
| name.c_str(), node->GetName().c_str()); | |||
| GE_LOGE("[Get][OpDesc] Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it", name.c_str(), | |||
| node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| for (auto &edge_anchor : netoutput->GetAllInDataAnchors()) { | |||
| auto edge_desc = netoutput_opdesc->MutableInputDesc(edge_anchor->GetIdx()); | |||
| if (edge_desc == nullptr) { | |||
| REPORT_INNER_ERROR("E19999", | |||
| "Invalid NetOutput node on sub graph %s, parent node %s, " | |||
| "can not find input tensor %d", | |||
| name.c_str(), node->GetName().c_str(), edge_anchor->GetIdx()); | |||
| GE_LOGE("[Get][Tensor] Invalid NetOutput node on sub graph %s, parent node %s, can not find input tensor %d", | |||
| name.c_str(), node->GetName().c_str(), edge_anchor->GetIdx()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| GELOGI("Netoutput in anchor index is %d, input tensor dim is %zu", edge_anchor->GetIdx(), | |||
| edge_desc->GetShape().GetDimNum()); | |||
| int ref_i; | |||
| if (!AttrUtils::GetInt(edge_desc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { | |||
| // if there is no ref index on the TensorDesc, it means the output data will be ignored outer. | |||
| continue; | |||
| } | |||
| GELOGI("Parent node index of edge desc is %d", ref_i); | |||
| if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllOutDataAnchorsSize()) { | |||
| return GRAPH_FAILED; | |||
| } | |||
| ref_out_tensors[ref_i].emplace_back(*edge_desc); | |||
| } | |||
| } | |||
| if (node->GetType() == WHILE) { | |||
| return UpdateParentNodeForWhile(node, ref_data_tensors, ref_out_tensors, changed_nodes); | |||
| } | |||
| return UpdateParentNodeForBranch(node, ref_out_tensors, changed_nodes); | |||
| } | |||
| graphStatus InferBasePass::UpdateParentNodeForWhile(NodePtr &node, | |||
| std::vector<std::vector<GeTensorDesc>> &ref_data_tensors, | |||
| std::vector<std::vector<GeTensorDesc>> &ref_out_tensors, | |||
| std::set<NodePtr> &changed_nodes) { | |||
| GELOGD("Enter update parent node shape for class while op process"); | |||
| if (ref_data_tensors.size() != ref_out_tensors.size()) { | |||
| REPORT_INNER_ERROR("E19999", "op:%s(%s) input number[%zu] and output number[%zu] is not same!", | |||
| node->GetName().c_str(), node->GetType().c_str(), ref_data_tensors.size(), | |||
| ref_out_tensors.size()); | |||
| GELOGE(GRAPH_FAILED, "[Check][Param] while op [%s] input number[%zu] and output number[%zu] is not same!", | |||
| node->GetName().c_str(), ref_data_tensors.size(), ref_out_tensors.size()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| for (size_t i = 0; i < ref_data_tensors.size(); i++) { | |||
| if (ref_out_tensors[i].size() != 1) { | |||
| REPORT_INNER_ERROR("E19999", "while op, every output should only find one output tensor in all graph!"); | |||
| GELOGE(GRAPH_FAILED, "[Check][Param] while op, every output should only find one output tensor in all graph!"); | |||
| return GRAPH_FAILED; | |||
| } | |||
| } | |||
| bool need_infer_again = false; | |||
| // check input and output | |||
| for (size_t i = 0; i < ref_out_tensors.size(); i++) { | |||
| if (ref_out_tensors[i].empty()) { | |||
| continue; | |||
| } | |||
| auto ref_out_tensor = ref_out_tensors[i].at(0); | |||
| auto out_shape = ref_out_tensor.MutableShape(); | |||
| vector<std::pair<int64_t, int64_t>> data_shape_range; | |||
| // ref_i's data and output tensor shape should be same | |||
| for (auto &tensor : ref_data_tensors[i]) { | |||
| if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { | |||
| REPORT_INNER_ERROR("E19999", "node[%s] does not support diff dtype or format among all ref output", | |||
| node->GetName().c_str()); | |||
| GELOGE(GRAPH_FAILED, "[Check][Param] node[%s] does not support diff dtype or format output.", | |||
| node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| auto data_shape = tensor.MutableShape(); | |||
| // input is dynamic, here use dim_num | |||
| if (data_shape.GetDims() != out_shape.GetDims()) { | |||
| GELOGI("After infer, While %s %zu output shape [%s] is not match with input shape [%s].Need infer again.", | |||
| node->GetName().c_str(), i, out_shape.ToString().c_str(), data_shape.ToString().c_str()); | |||
| if (data_shape.GetDimNum() != out_shape.GetDimNum()) { | |||
| ref_out_tensor.SetUnknownDimNumShape(); | |||
| } else { | |||
| for (size_t j = 0; j < data_shape.GetDimNum(); ++j) { | |||
| if (data_shape.GetDim(j) != out_shape.GetDim(j)) { | |||
| if (data_shape.GetDim(j) != UNKNOWN_DIM) { | |||
| // if input data is fix shape, output is different, need_infer_again | |||
| need_infer_again = true; | |||
| } | |||
| data_shape.SetDim(j, UNKNOWN_DIM); | |||
| } | |||
| // set shape rang of while, if dim is unknown ,set shape range as {1,-1} | |||
| if (data_shape.GetDim(j) == UNKNOWN_DIM) { | |||
| data_shape_range.emplace_back(std::make_pair(1, UNKNOWN_DIM)); | |||
| } else { | |||
| data_shape_range.emplace_back(std::make_pair(data_shape.GetDim(j), data_shape.GetDim(j))); | |||
| } | |||
| } | |||
| ref_out_tensor.SetShape(data_shape); | |||
| ref_out_tensor.SetShapeRange(data_shape_range); | |||
| } | |||
| } | |||
| } | |||
| //(void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor); | |||
| bool output_changed = false; | |||
| auto output_desc = node->GetOpDesc()->MutableOutputDesc(i); | |||
| (void)UpdateTensorDesc(ComGraphMakeShared<GeTensorDesc>(ref_out_tensor), output_desc,output_changed); | |||
| if (output_changed) { | |||
| changed_nodes.insert(node); | |||
| } | |||
| } | |||
| AttrUtils::SetBool(node->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, need_infer_again); | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus InferBasePass::UpdateOutputForMultiBatch(NodePtr &node, | |||
| std::vector<std::vector<GeTensorDesc>> &ref_out_tensors, | |||
| std::set<NodePtr> &changed_nodes) { | |||
| // check sub_graph shape. Get max for update. | |||
| for (size_t i = 0; i < ref_out_tensors.size(); ++i) { | |||
| if (ref_out_tensors[i].empty()) { | |||
| continue; | |||
| } | |||
| int64_t max_size = 0; | |||
| size_t max_shape_index = 0; | |||
| auto &ref_out_tensor = ref_out_tensors[i].at(0); | |||
| for (size_t j = 0; j < ref_out_tensors[i].size(); ++j) { | |||
| auto &tensor = ref_out_tensors[i].at(j); | |||
| if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { | |||
| REPORT_INNER_ERROR("E19999", "node[%s] does not support diff dtype among all ref output", | |||
| node->GetName().c_str()); | |||
| GELOGE(GRAPH_FAILED, "[Check][Param] node[%s] does not support diff dtype among all ref output", | |||
| node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| auto shape = tensor.MutableShape(); | |||
| int64_t size = 1; | |||
| for (auto dim : shape.GetDims()) { | |||
| if (dim != 0 && INT64_MAX / dim < size) { | |||
| REPORT_INNER_ERROR("E19999", "The shape:%s size overflow, node:%s", shape.ToString().c_str(), | |||
| node->GetName().c_str()); | |||
| GELOGE(PARAM_INVALID, "[Check][Overflow] The shape size overflow"); | |||
| return PARAM_INVALID; | |||
| } | |||
| size *= dim; | |||
| } | |||
| if (size > max_size) { | |||
| max_size = size; | |||
| max_shape_index = j; | |||
| } | |||
| } | |||
| //(void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensors[i].at(max_shape_index)); | |||
| bool output_changed = false; | |||
| auto output_desc = node->GetOpDesc()->MutableOutputDesc(i); | |||
| (void)UpdateTensorDesc(ComGraphMakeShared<GeTensorDesc>(ref_out_tensors[i].at(max_shape_index)), output_desc, | |||
| output_changed); | |||
| if (output_changed) { | |||
| changed_nodes.insert(node); | |||
| } | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus InferBasePass::UpdateParentNodeForBranch(NodePtr &node, | |||
| std::vector<std::vector<GeTensorDesc>> &ref_out_tensors, | |||
| std::set<NodePtr> &changed_nodes) { | |||
| GELOGD("Enter update parent node shape for class branch op process"); | |||
| if (node->GetOpDesc()->HasAttr(ATTR_NAME_BATCH_NUM)) { | |||
| return UpdateOutputForMultiBatch(node, ref_out_tensors, changed_nodes); | |||
| } | |||
| // check sub_graph shape.If not same ,do unknown shape process | |||
| for (size_t i = 0; i < ref_out_tensors.size(); i++) { | |||
| if (ref_out_tensors[i].empty()) { | |||
| continue; | |||
| } | |||
| auto ref_out_tensor = ref_out_tensors[i].at(0); | |||
| ge::GeShape &ref_out_tensor_shape = ref_out_tensor.MutableShape(); | |||
| for (auto &tensor : ref_out_tensors[i]) { | |||
| if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { | |||
| REPORT_INNER_ERROR("E19999", "node[%s] does not support diff dtype among all ref output, shape:%s", | |||
| node->GetName().c_str(), ref_out_tensor_shape.ToString().c_str()); | |||
| GELOGE(GRAPH_FAILED, "[Check][Param] node[%s] does not support diff dtype output", node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| auto shape = tensor.MutableShape(); | |||
| if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) { | |||
| GELOGD("node is %s, i : %zu, shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), i, | |||
| shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); | |||
| ref_out_tensor_shape = GeShape(UNKNOWN_RANK); | |||
| break; | |||
| } | |||
| for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) { | |||
| if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) { | |||
| continue; | |||
| } | |||
| GELOGD("node is %s, i : %zu, j: %zu ,shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), | |||
| i, j, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); | |||
| (void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM); | |||
| } | |||
| } | |||
| //(void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor); | |||
| bool output_changed = false; | |||
| auto output_desc = node->GetOpDesc()->MutableOutputDesc(i); | |||
| (void)UpdateTensorDesc(ComGraphMakeShared<GeTensorDesc>(ref_out_tensor), output_desc, output_changed); | |||
| if (output_changed) { | |||
| changed_nodes.insert(node); | |||
| } | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus InferBasePass::UpdateTensorDescToPeerInputs(NodePtr &node, std::set<NodePtr> &changed_nodes) { | |||
| auto op_desc = node->GetOpDesc(); | |||
| bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag(); | |||
| if (is_unknown_graph) { | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| for (const auto &out_anchor : node->GetAllOutDataAnchors()) { | |||
| auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx()); | |||
| for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) { | |||
| auto peer_anchor_opdesc = peer_anchor->GetOwnerNode()->GetOpDesc(); | |||
| if (peer_anchor_opdesc == nullptr) { | |||
| continue; | |||
| } | |||
| if (op_desc->GetId() < peer_anchor_opdesc->GetId() || peer_anchor_opdesc->GetType() == CONSTANT || | |||
| peer_anchor_opdesc->GetType() == CONSTANTOP) { | |||
| continue; | |||
| } | |||
| auto peer_input_desc = peer_anchor_opdesc->MutableInputDesc(peer_anchor->GetIdx()); | |||
| if (peer_input_desc == nullptr) { | |||
| continue; | |||
| } | |||
| bool changed = false; | |||
| auto ret = UpdateDescAttrForPeerInput(output_tensor, peer_input_desc, changed); | |||
| if (ret != GRAPH_SUCCESS) { | |||
| REPORT_CALL_ERROR("E19999", "Failed to update peer tensor desc attr"); | |||
| GE_LOGE("[Update][PeerInputDesc] Failed to update peer tensor desc attr"); | |||
| return ret; | |||
| } | |||
| if (changed) { | |||
| changed_nodes.insert(peer_anchor->GetOwnerNode()); | |||
| } | |||
| } | |||
| } | |||
| PrintInOutTensorShape(node, "after_infer"); | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus InferBasePass::UpdateDescAttrForPeerInput(const GeTensorDescPtr &src, GeTensorDescPtr dst, bool &changed){ | |||
| changed = false; | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| void InferBasePass::PrintInOutTensorShape(const NodePtr &node, const std::string &phase) { | |||
| if (!IsLogEnable(GE, DLOG_DEBUG)) { | |||
| return; | |||
| } | |||
| if (node == nullptr) { | |||
| REPORT_INNER_ERROR("E19999", "param node is nullprt, check invalid"); | |||
| GELOGE(GRAPH_FAILED, "[Check][Param] node is null"); | |||
| return; | |||
| } | |||
| ge::OpDescPtr op_desc = node->GetOpDesc(); | |||
| GE_IF_BOOL_EXEC(op_desc == nullptr, REPORT_INNER_ERROR("E19999", "node has no opdesc, check invalid"); | |||
| GELOGE(GRAPH_FAILED, "[Get][OpDesc] op_desc is null."); return ); | |||
| std::stringstream ss; | |||
| ss << "{"; | |||
| int32_t in_idx = 0; | |||
| int32_t out_idx = 0; | |||
| for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { | |||
| if (input_desc == nullptr) { | |||
| in_idx++; | |||
| continue; | |||
| } | |||
| if (in_idx > 0) { | |||
| ss << " "; | |||
| } | |||
| ss << "input_" << in_idx << " " | |||
| << "tensor: ["; | |||
| ss << "(shape:[" << input_desc->MutableShape().ToString() << "]),"; | |||
| ss << "(format:" << TypeUtils::FormatToSerialString(input_desc->GetFormat()) << "),"; | |||
| ss << "(dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) << "),"; | |||
| ss << "(origin_shape:" << input_desc->GetOriginShape().ToString() << "),"; | |||
| ss << "(origin_format:" << TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()) << "),"; | |||
| ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()) << "),"; | |||
| string range_str; | |||
| SerialShapeRange(input_desc, range_str); | |||
| ss << "(shape_range:" << range_str << ")]"; | |||
| in_idx++; | |||
| } | |||
| for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) { | |||
| if (output_desc == nullptr) { | |||
| out_idx++; | |||
| continue; | |||
| } | |||
| ss << " "; | |||
| ss << "output_" << out_idx << " " | |||
| << "tensor: ["; | |||
| ss << "(shape:[" << output_desc->MutableShape().ToString() << "]),"; | |||
| ss << "(format:" << TypeUtils::FormatToSerialString(output_desc->GetFormat()) << "),"; | |||
| ss << "(dtype:" << TypeUtils::DataTypeToSerialString(output_desc->GetDataType()) << "),"; | |||
| ss << "(origin_shape:" << output_desc->GetOriginShape().ToString() << "),"; | |||
| ss << "(origin_format:" << TypeUtils::FormatToSerialString(output_desc->GetOriginFormat()) << "),"; | |||
| ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(output_desc->GetOriginDataType()) << "),"; | |||
| string range_str; | |||
| SerialShapeRange(output_desc, range_str); | |||
| ss << "(shape_range:" << range_str << ")]"; | |||
| out_idx++; | |||
| } | |||
| ss << "}"; | |||
| GELOGD("Shape dump [%s], Node name: [%s]. %s", phase.c_str(), node->GetName().c_str(), ss.str().c_str()); | |||
| } | |||
| } // namespace ge | |||
| @@ -0,0 +1,51 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GRAPH_PASSES_INFER_BASE_PASS_H_ | |||
| #define GE_GRAPH_PASSES_INFER_BASE_PASS_H_ | |||
| #include "graph/passes/base_pass.h" | |||
| namespace ge { | |||
| class InferBasePass : public BaseNodePass { | |||
| public: | |||
| Status Run(NodePtr &node) override; | |||
| graphStatus InferAndUpdate(NodePtr &node, bool before_subgraph, std::set<NodePtr> &changed_nodes); | |||
| void PrintInOutTensorShape(const NodePtr &node, const std::string &phase); | |||
| protected: | |||
| virtual graphStatus Infer(NodePtr &node) = 0; | |||
| virtual graphStatus UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr dst, bool &changed) = 0; | |||
| virtual graphStatus UpdateDescAttrForPeerInput(const GeTensorDescPtr &src, GeTensorDescPtr dst, bool &changed); | |||
| virtual bool NeedInfer(const NodePtr &node); | |||
| virtual void AnalyzeFailedInfo(const NodePtr &node); | |||
| virtual Status DoRepassForLoopNode(NodePtr &node); // only for test inferXBase, will be deleted | |||
| private: | |||
| void AddChangedNodesImmediateRepass(std::set<NodePtr> &changed_nodes); | |||
| bool ContainsSubgraph(const NodePtr &node); | |||
| graphStatus UpdateTensorDescToSubgraphData(NodePtr &node, std::set<NodePtr> &changed_nodes); | |||
| graphStatus UpdateTensorDescToParentNode(NodePtr &node, std::set<NodePtr> &changed_nodes); | |||
| graphStatus UpdateParentNodeForWhile(NodePtr &node, std::vector<std::vector<GeTensorDesc>> &ref_data_tensors, | |||
| std::vector<std::vector<GeTensorDesc>> &ref_out_tensors, | |||
| std::set<NodePtr> &changed_nodes); | |||
| graphStatus UpdateParentNodeForBranch(NodePtr &node, std::vector<std::vector<GeTensorDesc>> &ref_out_tensors, | |||
| std::set<NodePtr> &changed_nodes); | |||
| graphStatus UpdateOutputForMultiBatch(NodePtr &node, std::vector<std::vector<GeTensorDesc>> &ref_out_tensors, | |||
| std::set<NodePtr> &changed_nodes); | |||
| graphStatus UpdateTensorDescToPeerInputs(NodePtr &node, std::set<NodePtr> &changed_nodes); | |||
| }; | |||
| } // namespace ge | |||
| #endif // GE_GRAPH_PASSES_INFER_BASE_PASS_H_ | |||
| @@ -19,15 +19,84 @@ | |||
| #include "framework/common/debug/ge_log.h" | |||
| #include "analyzer/analyzer.h" | |||
| #include "framework/common/util.h" | |||
| #include "graph/shape_refiner.h" | |||
| #include "graph/utils/graph_utils.h" | |||
| #include "graph/utils/node_utils.h" | |||
| #include "graph/common/omg_util.h" | |||
| #include "graph/debug/ge_attr_define.h" | |||
| #include "utils/tensor_utils.h" | |||
| #include "utils/type_utils.h" | |||
| #include "graph/debug/ge_util.h" | |||
| #include "graph/operator_factory_impl.h" | |||
| #include "graph/utils/graph_utils.h" | |||
| #include "graph/utils/node_utils.h" | |||
| #include "graph/utils/tensor_utils.h" | |||
| #include "graph/utils/type_utils.h" | |||
| namespace ge { | |||
| namespace { | |||
| const char *const kPreOpInputShapeRange = "_pre_op_in_range"; | |||
| thread_local std::unordered_map<NodePtr, InferenceContextPtr> context_map; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void InferShapePass::ClearContextMap() { context_map.clear(); } | |||
| InferenceContextPtr CreateInferenceContext(const std::unordered_map<NodePtr, InferenceContextPtr> &context_map, | |||
| const NodePtr &node) { | |||
| if (node == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "node is null"); | |||
| return nullptr; | |||
| } | |||
| InferenceContextPtr inference_context = std::shared_ptr<InferenceContext>(InferenceContext::Create()); | |||
| if (inference_context == nullptr) { | |||
| REPORT_CALL_ERROR("E19999", "Failed to alloc InferenceContext, node:%s", node->GetName().c_str()); | |||
| GELOGE(GRAPH_FAILED, "[Alloc][InferenceContext] failed."); | |||
| return nullptr; | |||
| } | |||
| auto all_in_data_anchors = node->GetAllInDataAnchors(); | |||
| std::vector<std::vector<ShapeAndType>> input_shapes_and_types(all_in_data_anchors.size()); | |||
| std::vector<std::string> marks; | |||
| bool has_input_shapes_and_types = false; | |||
| for (const auto &in_anchor : all_in_data_anchors) { | |||
| const auto &out_anchor = in_anchor->GetPeerOutAnchor(); | |||
| if (out_anchor == nullptr) { | |||
| continue; | |||
| } | |||
| auto input_node = out_anchor->GetOwnerNode(); | |||
| if (input_node == nullptr) { | |||
| continue; | |||
| } | |||
| auto iter = context_map.find(input_node); | |||
| if (iter != context_map.end()) { | |||
| const auto &src_context = iter->second; | |||
| GE_IF_BOOL_EXEC(src_context == nullptr, REPORT_INNER_ERROR("E19999", "src_context is null."); | |||
| GELOGE(GRAPH_FAILED, "[Check][Param] src_context is null."); return nullptr); | |||
| GELOGD("node:%s get %ld marks from node:%s", node->GetName().c_str(), src_context->GetMarks().size(), | |||
| input_node->GetName().c_str()); | |||
| for (auto mark : src_context->GetMarks()) { | |||
| marks.push_back(mark); | |||
| } | |||
| auto output_idx = out_anchor->GetIdx(); | |||
| auto input_idx = in_anchor->GetIdx(); | |||
| auto output_shape_and_type = src_context->GetOutputHandleShapesAndTypes(); | |||
| if (output_idx < static_cast<int>(output_shape_and_type.size())) { | |||
| GELOGI("Add shape and type from %s:%d to %s:%d", input_node->GetName().c_str(), output_idx, | |||
| node->GetName().c_str(), input_idx); | |||
| input_shapes_and_types[input_idx] = output_shape_and_type[output_idx]; | |||
| has_input_shapes_and_types = true; | |||
| } else { | |||
| GELOGI("[%s] Output out of range. index = %d, size = %zu", node->GetName().c_str(), output_idx, | |||
| output_shape_and_type.size()); | |||
| } | |||
| } | |||
| } | |||
| if (has_input_shapes_and_types) { | |||
| inference_context->SetInputHandleShapesAndTypes(std::move(input_shapes_and_types)); | |||
| } | |||
| inference_context->SetMarks(marks); | |||
| return inference_context; | |||
| } | |||
| void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) { | |||
| desc_str += "["; | |||
| @@ -61,7 +130,8 @@ std::string GetInTensorInfoWithString(const ge::NodePtr &node) { | |||
| if (in_idx > 0) { | |||
| ss << " "; | |||
| } | |||
| ss << "input_" << in_idx << " " << "tensor: ["; | |||
| ss << "input_" << in_idx << " " | |||
| << "tensor: ["; | |||
| ss << "(shape:[" << input_desc->MutableShape().ToString() << "]),"; | |||
| ss << "(format:" << TypeUtils::FormatToSerialString(input_desc->GetFormat()) << "),"; | |||
| ss << "(dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) << "),"; | |||
| @@ -76,27 +146,238 @@ std::string GetInTensorInfoWithString(const ge::NodePtr &node) { | |||
| return ss.str(); | |||
| } | |||
| Status InferShapePass::Run(NodePtr &node) { | |||
| // kOptimizeAfterSubGraph exist means after subgraph | |||
| auto ret = ShapeRefiner::InferShapeAndType(node, !OptionExists(kOptimizeAfterSubGraph)); | |||
| if (ret != GRAPH_SUCCESS) { | |||
| // select INFERSHAPE failed info | |||
| auto graph = node->GetOwnerComputeGraph(); | |||
| GE_CHECK_NOTNULL(graph); | |||
| auto root_graph = ge::GraphUtils::FindRootGraph(graph); | |||
| GE_CHECK_NOTNULL(root_graph); | |||
| analyzer::DataInfo analyze_info{root_graph->GetSessionID(), root_graph->GetGraphID(), | |||
| analyzer::INFER_SHAPE, node, "InferShapeFailed!"}; | |||
| (void)Analyzer::GetInstance()->DoAnalyze(analyze_info); | |||
| (void)Analyzer::GetInstance()->SaveAnalyzerDataToFile(root_graph->GetSessionID(), | |||
| root_graph->GetGraphID()); | |||
| void InferShapePass::AnalyzeFailedInfo(const NodePtr &node) { | |||
| auto graph = node->GetOwnerComputeGraph(); | |||
| if (graph == nullptr) { | |||
| GELOGW("Owner compute graph of node %s is nullptr", node->GetName().c_str()); | |||
| } | |||
| auto root_graph = ge::GraphUtils::FindRootGraph(graph); | |||
| if (root_graph == nullptr) { | |||
| GELOGW("Root compute graph of node %s is nullptr", node->GetName().c_str()); | |||
| } | |||
| analyzer::DataInfo analyze_info{root_graph->GetSessionID(), root_graph->GetGraphID(), analyzer::INFER_SHAPE, node, | |||
| "InferShapeFailed!"}; | |||
| (void)Analyzer::GetInstance()->DoAnalyze(analyze_info); | |||
| (void)Analyzer::GetInstance()->SaveAnalyzerDataToFile(root_graph->GetSessionID(), root_graph->GetGraphID()); | |||
| REPORT_CALL_ERROR("E19999", "Call InferShapeAndType for node:%s(%s) failed, input_tensor:%s", node->GetName().c_str(), | |||
| node->GetType().c_str(), GetInTensorInfoWithString(node).c_str()); | |||
| GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "infershape failed. node: %s", node->GetName().c_str()); | |||
| } | |||
| REPORT_CALL_ERROR("E19999", "Call InferShapeAndType for node:%s(%s) failed, input_tensor:%s", | |||
| node->GetName().c_str(), node->GetType().c_str(), GetInTensorInfoWithString(node).c_str()); | |||
| GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "infershape failed. node: %s", node->GetName().c_str()); | |||
| return GE_GRAPH_INFERSHAPE_FAILED; | |||
| graphStatus InferShapePass::UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr dst, bool &changed) { | |||
| changed = false; | |||
| const auto &dst_dims = dst->GetShape().GetDims(); | |||
| const auto &src_dims = src->GetShape().GetDims(); | |||
| if (dst_dims == src_dims) { | |||
| changed = true; | |||
| } | |||
| dst = src; | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus InferShapePass::UpdateDescAttrForPeerInput(const GeTensorDescPtr &src, GeTensorDescPtr dst, bool &changed) { | |||
| changed = false; | |||
| if (dst->GetShape().GetDims() == src->GetShape().GetDims()) { | |||
| changed = true; | |||
| } | |||
| dst->SetOriginShape(src->GetOriginShape()); | |||
| dst->SetShape(src->GetShape()); | |||
| dst->SetDataType(src->GetDataType()); | |||
| dst->SetOriginDataType(src->GetOriginDataType()); | |||
| std::vector<std::pair<int64_t, int64_t>> shape_range; | |||
| (void)src->GetShapeRange(shape_range); | |||
| dst->SetShapeRange(shape_range); | |||
| ge::TensorUtils::SetRealDimCnt(*dst, static_cast<uint32_t>(src->GetShape().GetDims().size())); | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus InferShapePass::Infer(NodePtr &node) { | |||
| bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag(); | |||
| auto opdesc = node->GetOpDesc(); | |||
| // some op can not infershape twice such as aipp | |||
| bool need_update_input = !is_unknown_graph && !opdesc->HasAttr("has_infered_verified"); | |||
| if (need_update_input) { | |||
| auto status = UpdateOpInputDesc(node); | |||
| if (status != GRAPH_SUCCESS) { | |||
| REPORT_CALL_ERROR("E19999", "update op input_desc failed! ret:%d, node:%s", status, node->GetName().c_str()); | |||
| GELOGE(GRAPH_FAILED, "[Update][OpInputDesc] failed! ret:%d", status); | |||
| return status; | |||
| } | |||
| } | |||
| if (node->Verify() != GRAPH_SUCCESS) { | |||
| REPORT_CALL_ERROR("E19999", "Verifying %s failed.", node->GetName().c_str()); | |||
| GELOGE(GRAPH_FAILED, "[Call][Verify] Verifying %s failed.", node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| PrintInOutTensorShape(node, "before_infershape"); | |||
| Operator op = OpDescUtils::CreateOperatorFromNode(node); | |||
| if (!is_unknown_graph) { | |||
| auto inference_context = CreateInferenceContext(context_map, node); | |||
| GE_CHECK_NOTNULL(inference_context); | |||
| GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size()); | |||
| op.SetInferenceContext(inference_context); | |||
| } | |||
| graphStatus status = CallInferShapeFunc(node, op); | |||
| if (status == GRAPH_PARAM_INVALID || status == GRAPH_SUCCESS) { | |||
| if (is_unknown_graph) { | |||
| PrintInOutTensorShape(node, "after_infershape when running"); | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| UpdateInputOutputOriginAttr(node); | |||
| } else { | |||
| REPORT_CALL_ERROR("E19999", "%s call infer function failed.", node->GetName().c_str()); | |||
| GELOGE(GRAPH_FAILED, "[Call][InferFunction] failed, node:%s.", node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| if (!is_unknown_graph) { | |||
| auto ctx_after_infer = op.GetInferenceContext(); | |||
| if (ctx_after_infer != nullptr) { | |||
| GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size()); | |||
| if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !ctx_after_infer->GetMarks().empty()) { | |||
| GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(), | |||
| ctx_after_infer->GetMarks().size()); | |||
| (void)context_map.emplace(node, ctx_after_infer); | |||
| } | |||
| } | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus InferShapePass::CallInferShapeFunc(NodePtr &node, Operator &op) { | |||
| auto op_desc = node->GetOpDesc(); | |||
| const auto &op_type = op_desc->GetType(); | |||
| auto ret = op_desc->CallInferFunc(op); | |||
| if (ret == GRAPH_PARAM_INVALID) { | |||
| // Op ir no infer func, try to get infer func from operator factory | |||
| auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType()); | |||
| if (node_op.IsEmpty()) { | |||
| GELOGW("get op from OperatorFactory fail. opType: %s", op_type.c_str()); | |||
| return ret; | |||
| } | |||
| GELOGD("get op from OperatorFactory success. opType: %s", op_type.c_str()); | |||
| auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op); | |||
| node_op.BreakConnect(); | |||
| if (temp_op_desc == nullptr) { | |||
| REPORT_CALL_ERROR("E19999", "GetOpDescFromOperator failed, return nullptr."); | |||
| GELOGE(GRAPH_FAILED, "[Get][OpDesc] temp op desc is null"); | |||
| return GRAPH_FAILED; | |||
| } | |||
| if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) { | |||
| GELOGW("InferShapeAndType UpdateInputName failed"); | |||
| for (const auto &out_desc : op_desc->GetAllOutputsDescPtr()) { | |||
| if (out_desc != nullptr && out_desc->GetShape().GetDims().empty()) { | |||
| break; | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| } | |||
| if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) { | |||
| GELOGW("InferShapeAndType UpdateOutputName failed"); | |||
| } | |||
| op_desc->AddInferFunc(temp_op_desc->GetInferFunc()); | |||
| ret = op_desc->CallInferFunc(op); | |||
| GELOGI("op CallInferFunc second. ret: %u", ret); | |||
| } | |||
| return ret; | |||
| } | |||
| void InferShapePass::UpdateInputOutputOriginAttr(NodePtr &node) { | |||
| auto op_desc = node->GetOpDesc(); | |||
| for (const auto &out_anchor : node->GetAllOutDataAnchors()) { | |||
| auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx()); | |||
| if (output_tensor == nullptr) { | |||
| continue; | |||
| } | |||
| if (output_tensor->MutableShape().GetDims().empty()) { | |||
| output_tensor->SetOriginShape(output_tensor->GetShape()); | |||
| } | |||
| ge::TensorUtils::SetRealDimCnt(*output_tensor, | |||
| static_cast<uint32_t>(output_tensor->GetOriginShape().GetDims().size())); | |||
| output_tensor->SetOriginDataType(output_tensor->GetDataType()); | |||
| // set output origin shape range | |||
| std::vector<std::pair<int64_t, int64_t>> range; | |||
| (void)output_tensor->GetShapeRange(range); | |||
| output_tensor->SetOriginShapeRange(range); | |||
| GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s", node->GetName().c_str(), | |||
| output_tensor->GetOriginShape().GetShapeSize(), | |||
| TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(), | |||
| TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str()); | |||
| } | |||
| for (const auto &in_anchor : node->GetAllInDataAnchors()) { | |||
| auto input_tensor = op_desc->MutableInputDesc(in_anchor->GetIdx()); | |||
| if (input_tensor == nullptr) { | |||
| continue; | |||
| } | |||
| // set input origin shape range | |||
| std::vector<std::pair<int64_t, int64_t>> range; | |||
| (void)input_tensor->GetShapeRange(range); | |||
| input_tensor->SetOriginShapeRange(range); | |||
| } | |||
| } | |||
| graphStatus InferShapePass::UpdateOpInputDesc(const ConstNodePtr &node_ptr) { | |||
| for (const auto &in_anchor : node_ptr->GetAllInDataAnchors()) { | |||
| auto in_idx = in_anchor->GetIdx(); | |||
| auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor(); | |||
| if (peer_out_data_anchor == nullptr) { | |||
| continue; | |||
| } | |||
| auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode(); | |||
| if (peer_out_data_node == nullptr || peer_out_data_node->GetOpDesc() == nullptr) { | |||
| continue; | |||
| } | |||
| int peer_out_idx = peer_out_data_anchor->GetIdx(); | |||
| auto peer_out_desc = peer_out_data_node->GetOpDesc()->MutableOutputDesc(static_cast<uint32_t>(peer_out_idx)); | |||
| // check shape and dtype continuity. do not stop process | |||
| auto in_desc = node_ptr->GetOpDesc()->MutableInputDesc(static_cast<uint32_t>(in_idx)); | |||
| if (in_desc == nullptr) { | |||
| continue; | |||
| } | |||
| auto in_shape = in_desc->MutableShape().GetDims(); | |||
| auto in_dtype = in_desc->GetDataType(); | |||
| auto peer_out_shape = peer_out_desc->MutableShape().GetDims(); | |||
| auto peer_out_dtype = peer_out_desc->GetDataType(); | |||
| if (peer_out_dtype != in_dtype) { | |||
| GELOGW( | |||
| "current node [%s] [%d]\'th in_dtype is [%s].peer output node [%s] [%d]\'th " | |||
| "output_dtype is [%s].The two dtype should be same! Please check graph and fix it", | |||
| node_ptr->GetName().c_str(), in_idx, TypeUtils::DataTypeToSerialString(in_dtype).c_str(), | |||
| peer_out_data_node->GetName().c_str(), peer_out_idx, TypeUtils::DataTypeToSerialString(peer_out_dtype).c_str()); | |||
| } else if ((!in_shape.empty()) && (in_shape != peer_out_shape)) { | |||
| string in_shape_str = " "; //Serial(in_shape); | |||
| string peer_out_shape_str = " "; //Serial(peer_out_shape); | |||
| GELOGW( | |||
| "current node [%s] [%d]\'th in_shape is [%s].peer output node [%s] [%d]\'th " | |||
| "output_shape is [%s].The two shape should be same! Please check graph and fix it", | |||
| node_ptr->GetName().c_str(), in_idx, in_shape_str.c_str(), peer_out_data_node->GetName().c_str(), peer_out_idx, | |||
| peer_out_shape_str.c_str()); | |||
| } | |||
| // refresh current node input desc | |||
| in_desc->SetOriginShape(peer_out_desc->GetOriginShape()); | |||
| in_desc->SetShape(peer_out_desc->MutableShape()); | |||
| in_desc->SetDataType(peer_out_desc->GetDataType()); | |||
| in_desc->SetOriginDataType(peer_out_desc->GetOriginDataType()); | |||
| if (peer_out_desc->MutableShape().GetDims() != UNKNOWN_RANK) { | |||
| std::vector<std::pair<int64_t, int64_t>> shape_range; | |||
| (void)peer_out_desc->GetShapeRange(shape_range); | |||
| in_desc->SetShapeRange(shape_range); | |||
| } | |||
| std::vector<int64_t> pre_op_in_range; | |||
| if (ge::AttrUtils::GetListInt(*peer_out_desc, kPreOpInputShapeRange, pre_op_in_range)) { | |||
| (void)ge::AttrUtils::SetListInt(*in_desc, kPreOpInputShapeRange, pre_op_in_range); | |||
| } | |||
| ge::TensorUtils::SetRealDimCnt(*in_desc, static_cast<uint32_t>(peer_out_desc->MutableShape().GetDims().size())); | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| Status InferShapePass::DoRepassForLoopNode(NodePtr &node) { | |||
| GE_CHK_STATUS_RET_NOLOG(RePassLoopNode(node)); | |||
| bool need_repass = false; | |||
| auto has_attr = AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, need_repass); | |||
| @@ -148,13 +429,13 @@ Status InferShapePass::RePassLoopNode(const NodePtr &node) { | |||
| std::string node_type; | |||
| GE_CHK_STATUS_RET(GetOriginalType(node, node_type), "Get original node type failed."); | |||
| if (kNextIterationOpTypes.count(node_type) > 0) { | |||
| return RePassNode(kMergeOpTypes); // Re-Pass Merge | |||
| return RePassNode(kMergeOpTypes); // Re-Pass Merge | |||
| } | |||
| if (kMergeOpTypes.count(node_type) > 0) { | |||
| if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { | |||
| node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); | |||
| return RePassNode(kSwitchOpTypes); // Re-Pass Switch | |||
| return RePassNode(kSwitchOpTypes); // Re-Pass Switch | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -162,12 +443,111 @@ Status InferShapePass::RePassLoopNode(const NodePtr &node) { | |||
| if (kSwitchOpTypes.count(node_type) > 0) { | |||
| if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { | |||
| node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); | |||
| return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeResume, "need resume"); // Resume Exit | |||
| return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeResume, "need resume"); // Resume Exit | |||
| } else { | |||
| return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeSuspend, "need suspend"); // Suspend Exit | |||
| return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeSuspend, "need suspend"); // Suspend Exit | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY | |||
| graphStatus InferShapePass::InferShapeAndType(NodePtr &node) { | |||
| GE_CHECK_NOTNULL(node); | |||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | |||
| InferShapePass pass; | |||
| std::set<NodePtr> unused_changed_nodes; | |||
| return pass.InferAndUpdate(node, true, unused_changed_nodes); | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY | |||
| graphStatus InferShapePass::InferShapeAndType(NodePtr &node, bool before_subgraph) { | |||
| GE_CHECK_NOTNULL(node); | |||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | |||
| InferShapePass pass; | |||
| std::set<NodePtr> unused_changed_nodes; | |||
| return pass.InferAndUpdate(node, before_subgraph, unused_changed_nodes); | |||
| } | |||
| graphStatus InferShapeForRunning::Infer(NodePtr &node) { | |||
| auto opdesc = node->GetOpDesc(); | |||
| vector<ge::DataType> temp_dtype; | |||
| for (auto &tensor_desc : opdesc->GetAllOutputsDescPtr()) { | |||
| temp_dtype.emplace_back(tensor_desc->GetDataType()); | |||
| } | |||
| PrintInOutTensorShape(node, "before_infershape when running"); | |||
| Operator op = OpDescUtils::CreateOperatorFromNode(node); | |||
| graphStatus status = CallInferShapeFuncForRunning(node, op); | |||
| if (status == GRAPH_PARAM_INVALID || status == GRAPH_SUCCESS) { | |||
| // ensure the dtype is not changed after infershape in running | |||
| auto after_opdesc = node->GetOpDesc(); | |||
| GE_IF_BOOL_EXEC(after_opdesc == nullptr, REPORT_INNER_ERROR("E19999", "param node has no opdesc, check invalid."); | |||
| GELOGE(GRAPH_FAILED, "[Get][OpDesc] after_opdesc is null."); return GRAPH_FAILED); | |||
| auto all_output_tensor = after_opdesc->GetAllOutputsDescPtr(); | |||
| for (size_t i = 0; i < all_output_tensor.size(); ++i) { | |||
| if (all_output_tensor.at(i)->GetDataType() != temp_dtype[i]) { | |||
| GELOGD("Op %s output %zu need reset dtype,original dtype is %s, new dtype is %s", node->GetName().c_str(), i, | |||
| TypeUtils::DataTypeToSerialString(all_output_tensor.at(i)->GetDataType()).c_str(), | |||
| TypeUtils::DataTypeToSerialString(temp_dtype[i]).c_str()); | |||
| all_output_tensor.at(i)->SetDataType(temp_dtype[i]); | |||
| } | |||
| } | |||
| PrintInOutTensorShape(node, "after_infershape when running"); | |||
| return GRAPH_SUCCESS; | |||
| } else { | |||
| REPORT_CALL_ERROR("E19999", "%s call infer function failed.", node->GetName().c_str()); | |||
| GELOGE(GRAPH_FAILED, "[Call][InferFunction] failed, node:%s.", node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| } | |||
| graphStatus InferShapeForRunning::CallInferShapeFuncForRunning(NodePtr &node, Operator &op) { | |||
| auto op_desc = node->GetOpDesc(); | |||
| const auto &op_type = op_desc->GetType(); | |||
| // Create InferenceContext to avoid null pointer access. | |||
| const static std::set<std::string> force_context_op_types{"Enter", "Switch", "RefSwitch"}; | |||
| if (force_context_op_types.count(op_type) > 0) { | |||
| GELOGD("Set InferenceContext for node [%s]", op_desc->GetName().c_str()); | |||
| op.SetInferenceContext(std::shared_ptr<InferenceContext>(InferenceContext::Create())); | |||
| } | |||
| // Get infer func and execute | |||
| auto ret = op_desc->CallInferFunc(op); | |||
| if (ret == GRAPH_PARAM_INVALID) { | |||
| GELOGD("NodeUtils::GetNodeType return value is: [%s]", NodeUtils::GetNodeType(*node).c_str()); | |||
| auto origin_type = NodeUtils::GetNodeType(*node); | |||
| auto infer_func = ge::OperatorFactoryImpl::GetInferShapeFunc(origin_type); | |||
| if (infer_func == nullptr) { | |||
| REPORT_INNER_ERROR("E19999", "Failed to Get InferFunc. type is %s", origin_type.c_str()); | |||
| GELOGE(GRAPH_FAILED, "[Get][InferFunc] failed. type is %s", origin_type.c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| op_desc->AddInferFunc(infer_func); | |||
| ret = op_desc->CallInferFunc(op); | |||
| GELOGI("op CallInferFunc second. ret: %u", ret); | |||
| } | |||
| return ret; | |||
| } | |||
| graphStatus InferShapeForRunning::UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr dst, bool &changed) { | |||
| changed = false; | |||
| const auto &dst_dims = dst->GetShape().GetDims(); | |||
| const auto &src_dims = src->GetShape().GetDims(); | |||
| if (dst_dims == src_dims) { | |||
| changed = true; | |||
| } | |||
| dst = src; | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY | |||
| graphStatus InferShapeForRunning::InferShapeAndTypeForRunning(NodePtr &node, bool before_subgraph) { | |||
| GE_CHECK_NOTNULL(node); | |||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | |||
| InferShapeForRunning pass; | |||
| std::set<NodePtr> unused_changed_nodes; | |||
| return pass.InferAndUpdate(node, before_subgraph, unused_changed_nodes); | |||
| } | |||
| } // namespace ge | |||
| @@ -17,22 +17,39 @@ | |||
| #ifndef GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ | |||
| #define GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ | |||
| #include "graph/passes/base_pass.h" | |||
| #include "graph/passes/infer_base_pass.h" | |||
| namespace ge { | |||
| class InferShapePass : public BaseNodePass { | |||
| class InferShapePass : public InferBasePass { | |||
| public: | |||
| /// | |||
| /// Entry of the InferShapePass optimizer | |||
| /// @param [in] graph: Input ComputeGraph | |||
| /// @return SUCCESS: Execution succeed | |||
| /// @return OTHERS: Execution failed | |||
| /// @author | |||
| /// | |||
| Status Run(ge::NodePtr &node) override; | |||
| static void ClearContextMap(); | |||
| graphStatus Infer(NodePtr &node) override; | |||
| graphStatus UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr dst, bool &changed) override; | |||
| graphStatus UpdateDescAttrForPeerInput(const GeTensorDescPtr &src, GeTensorDescPtr dst, bool &changed) override; | |||
| void AnalyzeFailedInfo(const NodePtr &node) override; | |||
| static graphStatus InferShapeAndType(NodePtr &node); // temp: visible static func is in cur pass | |||
| static graphStatus InferShapeAndType(NodePtr &node, bool before_subgraph); // temp: visible static func is in cur pass | |||
| private: | |||
| graphStatus CallInferShapeFunc(NodePtr &node, Operator &op); | |||
| void UpdateInputOutputOriginAttr(NodePtr &node); | |||
| graphStatus UpdateOpInputDesc(const ConstNodePtr &node_ptr); // maybe useless, just test infer_shape | |||
| Status DoRepassForLoopNode(NodePtr &node) override; // only for test inferXBase, will be deleted | |||
| Status RePassLoopNode(const NodePtr &node); // old repass logic, will be deleted | |||
| }; | |||
| class InferShapeForRunning : public InferBasePass { | |||
| public: | |||
| graphStatus Infer(NodePtr &node) override; | |||
| graphStatus UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr dst, bool &changed) override; | |||
| static graphStatus InferShapeAndTypeForRunning(NodePtr &node, bool before_subgraph); // temp: visible static func | |||
| private: | |||
| Status RePassLoopNode(const NodePtr &node); | |||
| graphStatus CallInferShapeFuncForRunning(NodePtr &node, Operator &op); | |||
| }; | |||
| } // namespace ge | |||
| #endif // GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ | |||
| @@ -22,6 +22,8 @@ | |||
| #include "common/math/math_util.h" | |||
| #include "hybrid/node_executor/node_executor.h" | |||
| #include "graph/passes/infershape_pass.h" // test new infershape pass | |||
| namespace ge { | |||
| namespace { | |||
| const int kAlignment = 32; | |||
| @@ -71,7 +73,7 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) { | |||
| GELOGD("[%s] Start to invoke InferShapeAndType", node_item.NodeName().c_str()); | |||
| { | |||
| RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] Start"); | |||
| GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndTypeForRunning(node_item.node, true), | |||
| GE_CHK_STATUS_RET(InferShapeForRunning::InferShapeAndTypeForRunning(const_cast<NodePtr &>(node_item.node), true), | |||
| "[Invoke][InferShapeAndType] for %s failed.", node_item.NodeName().c_str()); | |||
| RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] End"); | |||
| } | |||
| @@ -173,7 +175,7 @@ Status ShapeInferenceEngine::InferShapeForSubgraph(const NodeItem &node_item, co | |||
| for (auto &node : fused_subgraph.nodes) { | |||
| GELOGD("[%s] Start to invoke InferShapeAndType", node->GetName().c_str()); | |||
| GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndType(node)); | |||
| GE_CHK_STATUS_RET(InferShapePass::InferShapeAndType(const_cast<NodePtr &>(node))); | |||
| GELOGD("[%s] Done invoking InferShapeAndType", node->GetName().c_str()); | |||
| GE_CHK_STATUS_RET(UpdatePeerNodeShape(*node), | |||
| "[Update][PeerNodeShape] failed for [%s].", node->GetName().c_str()); | |||
| @@ -218,7 +218,9 @@ set(COMMON_SRC_FILES | |||
| "${GE_CODE_DIR}/ge/graph/passes/shape_operate_op_remove_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/assert_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/dropout_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/infer_base_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/infershape_pass.cc" | |||
| #"${GE_CODE_DIR}/ge/graph/passes/infer_value_range_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/unused_const_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/permute_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/ctrl_edge_transfer_pass.cc" | |||
| @@ -477,7 +479,7 @@ set(GRAPH_BUILD_COMMON_SRC_FILES | |||
| ) | |||
| set(GRAPH_PASS_COMMON_SRC_FILES | |||
| "${GE_CODE_DIR}/ge/graph/passes/pass_manager.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/pass_manager.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/base_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/variable_prepare_op_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/variable_ref_delete_op_pass.cc" | |||
| @@ -531,7 +533,9 @@ set(GRAPH_PASS_COMMON_SRC_FILES | |||
| "${GE_CODE_DIR}/ge/graph/passes/transpose_transdata_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/hccl_memcpy_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/no_use_reshape_remove_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/infer_base_pass.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/infershape_pass.cc" | |||
| #"${GE_CODE_DIR}/ge/graph/passes/infer_value_range_pass.cc" | |||
| "${GE_CODE_DIR}/ge/ge_local_engine/engine/host_cpu_engine.cc" | |||
| "${GE_CODE_DIR}/ge/analyzer/analyzer.cc" | |||
| "${GE_CODE_DIR}/ge/graph/passes/net_output_pass.cc" | |||
| @@ -706,6 +710,7 @@ set(PASS_TEST_FILES | |||
| "graph/passes/net_output_pass_unittest.cc" | |||
| "graph/passes/no_use_reshape_remove_pass_unittest.cc" | |||
| "graph/passes/infershape_pass_unittest.cc" | |||
| #"graph/passes/infer_value_range_pass_unittest.cc" | |||
| "graph/passes/mark_force_unknown_for_cond_pass_unittest.cc" | |||
| "graph/passes/multi_batch_clone_pass_unittest.cc" | |||
| "graph/passes/replace_with_empty_const_pass_unittest.cc" | |||