From: @zhangxiaokun9 Reviewed-by: @ji_chen,@wqtshg Signed-off-by: @ji_chentags/v1.3.0
| @@ -307,6 +307,7 @@ set(TRAIN_SRC_LIST | |||||
| "graph/passes/merge_to_stream_merge_pass.cc" | "graph/passes/merge_to_stream_merge_pass.cc" | ||||
| "graph/passes/merge_input_memcpy_pass.cc" | "graph/passes/merge_input_memcpy_pass.cc" | ||||
| "graph/passes/switch_to_stream_switch_pass.cc" | "graph/passes/switch_to_stream_switch_pass.cc" | ||||
| "graph/passes/mark_force_unknown_for_cond_pass.cc" | |||||
| "graph/passes/attach_stream_label_pass.cc" | "graph/passes/attach_stream_label_pass.cc" | ||||
| "graph/passes/switch_dead_branch_elimination.cc" | "graph/passes/switch_dead_branch_elimination.cc" | ||||
| "graph/passes/replace_transshape_pass.cc" | "graph/passes/replace_transshape_pass.cc" | ||||
| @@ -584,6 +585,7 @@ set(INFER_SRC_LIST | |||||
| "graph/passes/merge_to_stream_merge_pass.cc" | "graph/passes/merge_to_stream_merge_pass.cc" | ||||
| "graph/passes/merge_input_memcpy_pass.cc" | "graph/passes/merge_input_memcpy_pass.cc" | ||||
| "graph/passes/switch_to_stream_switch_pass.cc" | "graph/passes/switch_to_stream_switch_pass.cc" | ||||
| "graph/passes/mark_force_unknown_for_cond_pass.cc" | |||||
| "graph/passes/attach_stream_label_pass.cc" | "graph/passes/attach_stream_label_pass.cc" | ||||
| "graph/passes/multi_batch_pass.cc" | "graph/passes/multi_batch_pass.cc" | ||||
| "graph/passes/multi_batch_clone_pass.cc" | "graph/passes/multi_batch_clone_pass.cc" | ||||
| @@ -65,6 +65,7 @@ | |||||
| #include "graph/passes/merge_pass.h" | #include "graph/passes/merge_pass.h" | ||||
| #include "graph/passes/merge_input_memcpy_pass.h" | #include "graph/passes/merge_input_memcpy_pass.h" | ||||
| #include "graph/passes/merge_to_stream_merge_pass.h" | #include "graph/passes/merge_to_stream_merge_pass.h" | ||||
| #include "graph/passes/mark_force_unknown_for_cond_pass.h" | |||||
| #include "graph/passes/multi_batch_pass.h" | #include "graph/passes/multi_batch_pass.h" | ||||
| #include "graph/passes/next_iteration_pass.h" | #include "graph/passes/next_iteration_pass.h" | ||||
| #include "graph/passes/permute_pass.h" | #include "graph/passes/permute_pass.h" | ||||
| @@ -2535,7 +2536,9 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { | |||||
| // the prune pass should between SwitchPass and SwitchToStreamSwitchPass | // the prune pass should between SwitchPass and SwitchToStreamSwitchPass | ||||
| GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::Migration", new (std::nothrow) SubgraphConstMigrationPass)); | GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::Migration", new (std::nothrow) SubgraphConstMigrationPass)); | ||||
| GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ArgsClean", new (std::nothrow) UnusedArgsCleanPass)); | GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ArgsClean", new (std::nothrow) UnusedArgsCleanPass)); | ||||
| GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::PrunePass", new (std::nothrow) PrunePass)) | |||||
| GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::PrunePass", new (std::nothrow) PrunePass)); | |||||
| auto mark_force_unknown_pass = new (std::nothrow) MarkForceUnknownForCondPass; | |||||
| GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::MarkForceUnknownForCondPass", mark_force_unknown_pass)); | |||||
| GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::NextIterationPass", new (std::nothrow) NextIterationPass)) | GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::NextIterationPass", new (std::nothrow) NextIterationPass)) | ||||
| GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ControlTriggerPass", new (std::nothrow) ControlTriggerPass)) | GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ControlTriggerPass", new (std::nothrow) ControlTriggerPass)) | ||||
| GE_CHK_STATUS_RET( | GE_CHK_STATUS_RET( | ||||
| @@ -0,0 +1,125 @@ | |||||
| /** | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "mark_force_unknown_for_cond_pass.h" | |||||
| #include <queue> | |||||
| #include "graph/common/omg_util.h" | |||||
| namespace ge { | |||||
| namespace { | |||||
| const std::set<std::string> kMergeOpTypes{ MERGE, REFMERGE }; | |||||
| const std::set<std::string> kSwitchOpTypes{ SWITCH, REFSWITCH }; | |||||
| const std::set<std::string> kLoopMergeInputs{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION }; | |||||
| inline bool IsMergeInLoop(const NodePtr &node) { | |||||
| std::string node_type; | |||||
| (void)GetOriginalType(node, node_type); | |||||
| return kLoopMergeInputs.count(node_type) > 0; | |||||
| } | |||||
| } | |||||
| Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { | |||||
| GELOGD("MarkForceUnknownForCondPass Enter"); | |||||
| for (const auto &node : graph->GetDirectNode()) { | |||||
| std::string node_type; | |||||
| GE_CHK_STATUS_RET(GetOriginalType(node, node_type), "Get original type failed."); | |||||
| if (kMergeOpTypes.count(node_type) == 0) { | |||||
| continue; | |||||
| } | |||||
| const auto op_desc = node->GetOpDesc(); | |||||
| if (!op_desc->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE) && !IsUnknownShapeTensor(op_desc->GetOutputDesc(0))) { | |||||
| GELOGI("Merge[%s] has known shape, no need check switch", node->GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| const auto &all_in_nodes = node->GetInDataNodes(); | |||||
| if (std::any_of(all_in_nodes.begin(), all_in_nodes.end(), IsMergeInLoop)) { | |||||
| continue; // LoopCond marked in NextIterationPass. | |||||
| } | |||||
| MarkUnknownForSwitch(node); | |||||
| } | |||||
| GELOGD("MarkForceUnknownForCondPass Leave"); | |||||
| return SUCCESS; | |||||
| } | |||||
| /// | |||||
| /// @brief Mark force unknown shape for Switch node | |||||
| /// @param [in] merge node | |||||
| /// @return | |||||
| /// | |||||
| void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node) { | |||||
| // Switch --> {Switch --> Merge} --> Merge | |||||
| std::vector<NodePtr> switch_group; | |||||
| std::unordered_set<NodePtr> nodes_seen; | |||||
| std::queue<std::pair<NodePtr, uint32_t>> search_queue({{node, 0}}); | |||||
| while (!search_queue.empty()) { | |||||
| const auto dst_node = search_queue.front().first; | |||||
| const auto dst_span = search_queue.front().second; | |||||
| search_queue.pop(); | |||||
| // Switch --> Identity --> Constant | |||||
| for (const auto &in_node : dst_node->GetInControlNodes()) { | |||||
| if (nodes_seen.count(in_node) > 0) { | |||||
| GELOGD("Travel node: %s, Skip already seen node: %s", dst_node->GetName().c_str(), in_node->GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| nodes_seen.insert(in_node); | |||||
| if (in_node->GetType() == IDENTITY) { | |||||
| GELOGD("Travel node: %s, In control: %s, span is: %u", dst_node->GetName().c_str(), | |||||
| in_node->GetName().c_str(), dst_span); | |||||
| search_queue.push({in_node, dst_span}); | |||||
| } | |||||
| } | |||||
| for (const auto &in_node : dst_node->GetInDataNodes()) { | |||||
| if (nodes_seen.count(in_node) > 0) { | |||||
| GELOGD("Travel node: %s, Skip already seen node: %s", dst_node->GetName().c_str(), in_node->GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| nodes_seen.insert(in_node); | |||||
| std::string node_type; | |||||
| (void)GetOriginalType(in_node, node_type); | |||||
| GELOGD("Travel node: %s, %s node: %s, span is: %u", dst_node->GetName().c_str(), node_type.c_str(), | |||||
| in_node->GetName().c_str(), dst_span); | |||||
| if (kSwitchOpTypes.count(node_type) > 0) { // Switch input node. | |||||
| if (dst_span > 0) { | |||||
| search_queue.push({in_node, dst_span - 1}); | |||||
| } else { | |||||
| switch_group.emplace_back(in_node); | |||||
| } | |||||
| } else if (kMergeOpTypes.count(node_type) > 0) { // Merge input node. | |||||
| search_queue.push({in_node, dst_span + 1}); | |||||
| } else { | |||||
| search_queue.push({in_node, dst_span}); | |||||
| } | |||||
| } | |||||
| } | |||||
| for (const auto &n : switch_group) { | |||||
| MarkForceUnknownShape(n, true); | |||||
| } | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,36 @@ | |||||
| /** | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GRAPH_PASSES_MARK_FORCE_UNKNOWN_FOR_COND_PASS_H_ | |||||
| #define GE_GRAPH_PASSES_MARK_FORCE_UNKNOWN_FOR_COND_PASS_H_ | |||||
| #include "inc/graph_pass.h" | |||||
| namespace ge { | |||||
| class MarkForceUnknownForCondPass : public GraphPass { | |||||
| public: | |||||
| Status Run(ComputeGraphPtr graph); | |||||
| private: | |||||
| /// | |||||
| /// @brief Mark force unknown shape for Switch node | |||||
| /// @param [in] merge node | |||||
| /// @return | |||||
| /// | |||||
| void MarkUnknownForSwitch(const NodePtr &node); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // GE_GRAPH_PASSES_MARK_FORCE_UNKNOWN_FOR_COND_PASS_H_ | |||||
| @@ -16,18 +16,11 @@ | |||||
| #include "graph/passes/merge_input_memcpy_pass.h" | #include "graph/passes/merge_input_memcpy_pass.h" | ||||
| #include <queue> | |||||
| #include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
| #include "ge/ge_api_types.h" | #include "ge/ge_api_types.h" | ||||
| #include "graph/common/omg_util.h" | #include "graph/common/omg_util.h" | ||||
| namespace ge { | namespace ge { | ||||
| namespace { | |||||
| const std::set<std::string> kLoopMergeInputs{ | |||||
| ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION | |||||
| }; | |||||
| } | |||||
| Status MergeInputMemcpyPass::Run(ComputeGraphPtr graph) { | Status MergeInputMemcpyPass::Run(ComputeGraphPtr graph) { | ||||
| GELOGD("MergeInputMemcpyPass Enter"); | GELOGD("MergeInputMemcpyPass Enter"); | ||||
| std::unordered_map<NodePtr, std::vector<NodePtr>> switch_groups; | std::unordered_map<NodePtr, std::vector<NodePtr>> switch_groups; | ||||
| @@ -41,10 +34,8 @@ Status MergeInputMemcpyPass::Run(ComputeGraphPtr graph) { | |||||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | GE_CHECK_NOTNULL(node->GetOpDesc()); | ||||
| GE_CHK_STATUS_RET(AddMemcpyAsyncNodes(graph, node, node->GetOpDesc()->HasAttr(ATTR_INSERT_BY_MBATCH)), | GE_CHK_STATUS_RET(AddMemcpyAsyncNodes(graph, node, node->GetOpDesc()->HasAttr(ATTR_INSERT_BY_MBATCH)), | ||||
| "Merge add memcpy node failed."); | "Merge add memcpy node failed."); | ||||
| CollectSwitchGroup(node, switch_groups); | |||||
| } | } | ||||
| MarkUnknownForSwitch(switch_groups); | |||||
| GELOGD("MergeInputMemcpyPass Leave"); | GELOGD("MergeInputMemcpyPass Leave"); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -114,94 +105,4 @@ NodePtr MergeInputMemcpyPass::CreateMemcpyAsyncNode(const ComputeGraphPtr &graph | |||||
| return graph->AddNode(op_desc); | return graph->AddNode(op_desc); | ||||
| } | } | ||||
| /// | |||||
| /// @brief Mark force unknown shape for Switch node | |||||
| /// @param [in] merge node | |||||
| /// @param [out] switch_groups | |||||
| /// @return | |||||
| /// | |||||
| void MergeInputMemcpyPass::CollectSwitchGroup(const NodePtr &node, | |||||
| std::unordered_map<NodePtr, std::vector<NodePtr>> &switch_groups) { | |||||
| const auto &op_desc = node->GetOpDesc(); | |||||
| for (const auto &in_anchor : node->GetAllInDataAnchors()) { | |||||
| const auto &src_out_anchor = in_anchor->GetPeerOutAnchor(); | |||||
| if (src_out_anchor == nullptr) { | |||||
| continue; | |||||
| } | |||||
| std::string node_type; | |||||
| GetOriginalType(src_out_anchor->GetOwnerNode(), node_type); | |||||
| if (kLoopMergeInputs.count(node_type) > 0) { | |||||
| return; | |||||
| } | |||||
| } | |||||
| // Switch --> {Switch --> Merge} --> Merge | |||||
| std::queue<std::pair<NodePtr, uint32_t>> search_queue; | |||||
| search_queue.push({node, 0}); | |||||
| std::vector<NodePtr> &switch_group = switch_groups[node]; | |||||
| while (!search_queue.empty()) { | |||||
| const auto dst_node = search_queue.front().first; | |||||
| const auto dst_span = search_queue.front().second; | |||||
| search_queue.pop(); | |||||
| // Switch --> Identity --> Constant | |||||
| for (const auto &in_ctrl_node : dst_node->GetInControlNodes()) { | |||||
| if (in_ctrl_node->GetType() == IDENTITY) { | |||||
| GELOGD("Travel node: %s, In control: %s, span is: %u", | |||||
| dst_node->GetName().c_str(), in_ctrl_node->GetName().c_str(), dst_span); | |||||
| search_queue.push({in_ctrl_node, dst_span}); | |||||
| } | |||||
| } | |||||
| for (const auto &in_data_node : dst_node->GetInDataNodes()) { | |||||
| std::string node_type; | |||||
| GetOriginalType(in_data_node, node_type); | |||||
| GELOGD("Travel node: %s, %s node: %s, span is: %u", | |||||
| dst_node->GetName().c_str(), node_type.c_str(), in_data_node->GetName().c_str(), dst_span); | |||||
| if (node_type == SWITCH || node_type == REFSWITCH) { | |||||
| if (dst_span > 0) { | |||||
| search_queue.push({in_data_node, dst_span - 1}); | |||||
| } else { | |||||
| switch_group.emplace_back(in_data_node); | |||||
| } | |||||
| } else if (node_type == MERGE || node_type == REFMERGE) { | |||||
| search_queue.push({in_data_node, dst_span + 1}); | |||||
| } else { | |||||
| search_queue.push({in_data_node, dst_span}); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (IsUnknownShapeTensor(op_desc->GetOutputDesc(0)) || op_desc->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE)) { | |||||
| GELOGI("Mark [%s] as for unknown shape, switch groups: %zu", node->GetName().c_str(), switch_groups.size()); | |||||
| MarkForceUnknownShape(node, true); | |||||
| for (const auto &n : switch_group) { | |||||
| MarkForceUnknownShape(n, true); | |||||
| } | |||||
| } | |||||
| } | |||||
| void MergeInputMemcpyPass::MarkUnknownForSwitch(const std::unordered_map<NodePtr, std::vector<NodePtr>> &switch_groups) { | |||||
| std::function<bool(const NodePtr &)> callback = [](const NodePtr &n) { | |||||
| return n->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE); | |||||
| }; | |||||
| for (const auto &item : switch_groups) { | |||||
| const auto &node = item.first; | |||||
| if (node->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE)) { | |||||
| continue; | |||||
| } | |||||
| const std::vector<NodePtr> &switch_group = item.second; | |||||
| if (std::any_of(switch_group.begin(), switch_group.end(), callback)) { | |||||
| GELOGI("Mark [%s] as force unknown shape, switch nodes: %zu", node->GetName().c_str(), switch_group.size()); | |||||
| MarkForceUnknownShape(node, true); | |||||
| for (const auto &n : switch_group) { | |||||
| MarkForceUnknownShape(n, true); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -44,21 +44,6 @@ class MergeInputMemcpyPass : public GraphPass { | |||||
| /// | /// | ||||
| NodePtr CreateMemcpyAsyncNode(const ComputeGraphPtr &graph, const std::string &name, | NodePtr CreateMemcpyAsyncNode(const ComputeGraphPtr &graph, const std::string &name, | ||||
| const OutDataAnchorPtr &out_data_anchor, bool multi_batch_flag); | const OutDataAnchorPtr &out_data_anchor, bool multi_batch_flag); | ||||
| /// | |||||
| /// @brief Mark force unknown shape for Switch node | |||||
| /// @param [in] merge node | |||||
| /// @param [out] switch_groups | |||||
| /// @return | |||||
| /// | |||||
| void CollectSwitchGroup(const NodePtr &node, std::unordered_map<NodePtr, std::vector<NodePtr>> &switch_groups); | |||||
| /// | |||||
| /// @brief Mark force unknown shape for Switch node | |||||
| /// @param [in] switch_groups | |||||
| /// @return | |||||
| /// | |||||
| void MarkUnknownForSwitch(const std::unordered_map<NodePtr, std::vector<NodePtr>> &switch_groups); | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| #endif // GE_GRAPH_PASSES_MERGE_ADD_INPUT_MEMCPY_PASS_H_ | #endif // GE_GRAPH_PASSES_MERGE_ADD_INPUT_MEMCPY_PASS_H_ | ||||
| @@ -239,6 +239,7 @@ set(COMMON_SRC_FILES | |||||
| "${GE_CODE_DIR}/ge/graph/passes/merge_to_stream_merge_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/merge_to_stream_merge_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/merge_input_memcpy_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/merge_input_memcpy_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/switch_to_stream_switch_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/switch_to_stream_switch_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/mark_force_unknown_for_cond_pass.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/passes/attach_stream_label_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/attach_stream_label_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/multi_batch_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/multi_batch_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/multi_batch_clone_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/multi_batch_clone_pass.cc" | ||||
| @@ -703,6 +704,7 @@ set(PASS_TEST_FILES | |||||
| "graph/passes/net_output_pass_unittest.cc" | "graph/passes/net_output_pass_unittest.cc" | ||||
| "graph/passes/no_use_reshape_remove_pass_unittest.cc" | "graph/passes/no_use_reshape_remove_pass_unittest.cc" | ||||
| "graph/passes/infershape_pass_unittest.cc" | "graph/passes/infershape_pass_unittest.cc" | ||||
| "graph/passes/mark_force_unknown_for_cond_pass_unittest.cc" | |||||
| "graph/passes/multi_batch_clone_pass_unittest.cc" | "graph/passes/multi_batch_clone_pass_unittest.cc" | ||||
| "graph/passes/replace_with_empty_const_pass_unittest.cc" | "graph/passes/replace_with_empty_const_pass_unittest.cc" | ||||
| "graph/passes/link_gen_mask_nodes_pass_unittest.cc" | "graph/passes/link_gen_mask_nodes_pass_unittest.cc" | ||||
| @@ -0,0 +1,230 @@ | |||||
| /** | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <gtest/gtest.h> | |||||
| #define protected public | |||||
| #define private public | |||||
| #include "graph/passes/mark_force_unknown_for_cond_pass.h" | |||||
| #include "graph/utils/tensor_utils.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "graph/operator_factory.h" | |||||
| #include "graph/operator_reg.h" | |||||
| #include "graph_builder_utils.h" | |||||
| using namespace std; | |||||
| using namespace testing; | |||||
| namespace ge { | |||||
| class UtestMarkForceUnknownForCondPass : public testing::Test { | |||||
| protected: | |||||
| void SetUp() {} | |||||
| void TearDown() {} | |||||
| }; | |||||
| static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) { | |||||
| OpDescPtr op_desc = std::make_shared<OpDesc>(name, type); | |||||
| op_desc->SetStreamId(0); | |||||
| static int32_t index = 0; | |||||
| op_desc->SetId(index++); | |||||
| GeTensorDesc tensor(GeShape(), FORMAT_NCHW, DT_FLOAT); | |||||
| TensorUtils::SetSize(tensor, 512); | |||||
| vector<int64_t> input_offset; | |||||
| for (int i = 0; i < in_num; i++) { | |||||
| op_desc->AddInputDesc(tensor); | |||||
| input_offset.emplace_back(1024); | |||||
| } | |||||
| op_desc->SetInputOffset(input_offset); | |||||
| vector<int64_t> output_offset; | |||||
| for (int i = 0; i < out_num; i++) { | |||||
| op_desc->AddOutputDesc(tensor); | |||||
| output_offset.emplace_back(1024); | |||||
| } | |||||
| op_desc->SetOutputOffset(output_offset); | |||||
| op_desc->SetWorkspace({}); | |||||
| op_desc->SetWorkspaceBytes({}); | |||||
| op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE"); | |||||
| const auto stub_func = [](Operator &op) { return GRAPH_SUCCESS; }; | |||||
| op_desc->AddInferFunc(stub_func); | |||||
| op_desc->AddInferFormatFunc(stub_func); | |||||
| op_desc->AddVerifierFunc(stub_func); | |||||
| return graph.AddNode(op_desc); | |||||
| } | |||||
| static void CreateLoopGraph(ComputeGraphPtr &graph, NodePtr &merge) { | |||||
| /******************************************************************************* | |||||
| * Exit Identify | |||||
| * \ / \. | |||||
| * \ / \. | |||||
| * Switch Add | |||||
| * / | | | |||||
| * / | | | |||||
| * / | | | |||||
| * LoopCond | | | |||||
| * \ | | | |||||
| * \ | | | |||||
| * \ | | | |||||
| * Less | | | |||||
| * \ | NextIteration | |||||
| * \ | | | |||||
| * \ | | | |||||
| * Merge <---------| | |||||
| * | | |||||
| * | | |||||
| * Enter | |||||
| ******************************************************************************/ | |||||
| auto data1 = CreateNode(*graph, "data", DATA, 1, 1); | |||||
| auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1); | |||||
| auto merge1 = CreateNode(*graph, "merge", MERGE, 2, 2); | |||||
| auto less1 = CreateNode(*graph, "less", LESS, 2, 1); | |||||
| auto loop1 = CreateNode(*graph, "loopcond", LOOPCOND, 1, 1); | |||||
| auto switch1 = CreateNode(*graph, "switch", SWITCH, 2, 2); | |||||
| auto ident1 = CreateNode(*graph, "identity", IDENTITY, 1, 1); | |||||
| auto add1 = CreateNode(*graph, "add", ADD, 2, 1); | |||||
| auto next1 = CreateNode(*graph, "next", NEXTITERATION, 1, 1); | |||||
| auto exit1 = CreateNode(*graph, "exit", EXIT, 1, 1); | |||||
| auto value0 = CreateNode(*graph, "const", CONSTANT, 0, 1); | |||||
| auto value1 = CreateNode(*graph, "const", CONSTANT, 0, 1); | |||||
| auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1); | |||||
| GraphUtils::AddEdge(data1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), less1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(less1->GetOutDataAnchor(0), loop1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), switch1->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(switch1->GetOutDataAnchor(0), exit1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(switch1->GetOutDataAnchor(1), ident1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(ident1->GetOutDataAnchor(0), add1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(value1->GetOutDataAnchor(0), add1->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(add1->GetOutDataAnchor(0), next1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(next1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); | |||||
| merge = merge1; | |||||
| } | |||||
| static void CreateCondGraph(ComputeGraphPtr &graph, NodePtr &merge) { | |||||
| /******************************************************************************* | |||||
| * NetOutput | |||||
| * | | |||||
| * | | |||||
| * Merge | |||||
| * / \. | |||||
| * / \. | |||||
| * / \. | |||||
| * Add Sub | |||||
| * | \ | \. | |||||
| * | \ | \. | |||||
| * | \ | Const | |||||
| * | \ | \. | |||||
| * | \ | Identify | |||||
| * | \ | | | |||||
| * Switch Switch Switch Switch | |||||
| * | | | | | | |||||
| * | | | | | | |||||
| * x y Cond z | |||||
| ******************************************************************************/ | |||||
| auto data1 = CreateNode(*graph, "data_x", DATA, 1, 1); | |||||
| auto data2 = CreateNode(*graph, "data_y", DATA, 1, 1); | |||||
| auto data3 = CreateNode(*graph, "data_z", DATA, 1, 1); | |||||
| auto less1 = CreateNode(*graph, "less", LESS, 2, 1); | |||||
| auto switch1 = CreateNode(*graph, "switch_x", SWITCH, 2, 2); | |||||
| auto switch2 = CreateNode(*graph, "switch_y", SWITCH, 2, 2); | |||||
| auto switch3 = CreateNode(*graph, "switch_z", SWITCH, 2, 2); | |||||
| auto switch4 = CreateNode(*graph, "switch_i", SWITCH, 2, 2); | |||||
| auto add1 = CreateNode(*graph, "add", ADD, 2, 1); | |||||
| auto sub1 = CreateNode(*graph, "add", SUB, 2, 1); | |||||
| auto ident1 = CreateNode(*graph, "identity", IDENTITY, 1, 1); | |||||
| auto const1 = CreateNode(*graph, "const", CONSTANT, 0, 1); | |||||
| auto merge1 = CreateNode(*graph, "merge", MERGE, 2, 2); | |||||
| auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1); | |||||
| GraphUtils::AddEdge(data1->GetOutDataAnchor(0), less1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(data2->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(data1->GetOutDataAnchor(0), switch1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(data2->GetOutDataAnchor(0), switch2->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(data3->GetOutDataAnchor(0), switch3->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(less1->GetOutDataAnchor(0), switch4->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(less1->GetOutDataAnchor(0), switch1->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(less1->GetOutDataAnchor(0), switch2->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(less1->GetOutDataAnchor(0), switch3->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(less1->GetOutDataAnchor(0), switch4->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(switch1->GetOutDataAnchor(0), add1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(switch2->GetOutDataAnchor(0), add1->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(switch3->GetOutDataAnchor(0), sub1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(switch4->GetOutDataAnchor(0), ident1->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(ident1->GetOutControlAnchor(), const1->GetInControlAnchor()); | |||||
| GraphUtils::AddEdge(const1->GetOutDataAnchor(0), sub1->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(add1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(sub1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); | |||||
| merge = merge1; | |||||
| } | |||||
| TEST_F(UtestMarkForceUnknownForCondPass, skip_while_loop_merge) { | |||||
| auto graph = std::make_shared<ComputeGraph>("test_graph"); | |||||
| NodePtr merge; | |||||
| CreateLoopGraph(graph, merge); | |||||
| AttrUtils::SetBool(merge->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true); | |||||
| MarkForceUnknownForCondPass mark_force_unknown_pass; | |||||
| EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS); // skip LoopCond | |||||
| } | |||||
| TEST_F(UtestMarkForceUnknownForCondPass, skip_known_shape_merge) { | |||||
| auto graph = std::make_shared<ComputeGraph>("test_graph"); | |||||
| NodePtr merge; | |||||
| CreateCondGraph(graph, merge); | |||||
| MarkForceUnknownForCondPass mark_force_unknown_pass; | |||||
| EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS); // skip known shape merge | |||||
| } | |||||
| TEST_F(UtestMarkForceUnknownForCondPass, mark_unknown_shape_merge) { | |||||
| auto graph = std::make_shared<ComputeGraph>("test_graph"); | |||||
| NodePtr merge; | |||||
| CreateCondGraph(graph, merge); | |||||
| auto tensor_desc = merge->GetOpDesc()->GetOutputDesc(0); | |||||
| tensor_desc.SetShape(GeShape({-1})); // Set for unknown. | |||||
| merge->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); | |||||
| MarkForceUnknownForCondPass mark_force_unknown_pass; | |||||
| EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS); | |||||
| } | |||||
| } // namespace ge | |||||