Merge pull request !1876 from 张晓昆/mastertags/v1.5.1
| @@ -16,8 +16,6 @@ | |||
| #include "graph/passes/mark_force_unknown_for_cond_pass.h" | |||
| #include <queue> | |||
| #include "graph/utils/node_utils.h" | |||
| #include "graph/common/omg_util.h" | |||
| @@ -26,17 +24,7 @@ namespace { | |||
| inline bool IsMergeInLoop(const NodePtr &node) { | |||
| const static std::set<std::string> kLoopMergeInputs{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION }; | |||
| std::string node_type; | |||
| (void)GetOriginalType(node, node_type); | |||
| return kLoopMergeInputs.count(node_type) > 0; | |||
| } | |||
| inline bool IsSwitchInLoop(const NodePtr &node) { | |||
| const static std::set<std::string> kLoopSwitchInputs{ MERGE, REFMERGE, LOOPCOND }; | |||
| std::string node_type; | |||
| (void)GetOriginalType(node, node_type); | |||
| return kLoopSwitchInputs.count(node_type) > 0; | |||
| return kLoopMergeInputs.count(NodeUtils::GetNodeType(node)) > 0; | |||
| } | |||
| } | |||
| @@ -44,10 +32,7 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { | |||
| GELOGD("MarkForceUnknownForCondPass Enter"); | |||
| std::map<NodePtr, std::vector<NodePtr>> switch_groups; | |||
| for (const auto &node : graph->GetDirectNode()) { | |||
| std::string node_type; | |||
| GE_CHK_STATUS_RET(GetOriginalType(node, node_type), | |||
| "[Get][OriginalType] of node in graph:%s failed.", graph->GetName().c_str()); | |||
| if (kMergeOpTypes.count(node_type) == 0) { | |||
| if (kMergeOpTypes.count(NodeUtils::GetNodeType(node)) == 0) { | |||
| continue; | |||
| } | |||
| @@ -64,6 +49,51 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { | |||
| return SUCCESS; | |||
| } | |||
| /// | |||
| /// @brief Deal with Switch node for LoopCond | |||
| /// @param [in] Switch node | |||
| /// @param [in] dest span | |||
| /// @param [out] Search queue | |||
| /// @return true: Switch In while loop / false: Not in while Loop. | |||
| /// | |||
| bool MarkForceUnknownForCondPass::DealAsLoopSwitch(const NodePtr &node, uint32_t dst_span, | |||
| std::queue<std::pair<NodePtr, uint32_t>> &search_queue) { | |||
| /// LoopCond --->\. | |||
| /// \. | |||
| /// Enter-----------+ \. | |||
| /// +--> Merge --> Switch --> Exit | |||
| /// NextIteration---+ | |||
| const auto is_loop_op = [](const NodePtr &n) { | |||
| return NodeUtils::GetNodeType(n) == LOOPCOND; | |||
| }; | |||
| const auto is_exit_op = [](const NodePtr &n) { | |||
| return kExitOpTypes.count(NodeUtils::GetNodeType(n)) > 0; | |||
| }; | |||
| const auto src_nodes = node->GetInAllNodes(); | |||
| const auto dst_nodes = node->GetOutAllNodes(); | |||
| if (std::none_of(src_nodes.begin(), src_nodes.end(), is_loop_op) && | |||
| std::none_of(dst_nodes.begin(), dst_nodes.end(), is_exit_op)) { | |||
| return false; | |||
| } | |||
| for (const auto &m : src_nodes) { | |||
| if (kMergeOpTypes.count(NodeUtils::GetNodeType(m)) > 0) { | |||
| for (const auto &n : m->GetInAllNodes()) { | |||
| if (kNextIterationOpTypes.count(NodeUtils::GetNodeType(n)) > 0) { | |||
| continue; | |||
| } | |||
| search_queue.push({n, dst_span}); | |||
| GELOGD("Travel in Loop: %s <-- %s <-- %s, span is: %u", node->GetName().c_str(), m->GetName().c_str(), | |||
| n->GetName().c_str(), dst_span); | |||
| } | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| /// | |||
| /// @brief Mark force unknown shape for Switch node | |||
| /// @param [in] merge node | |||
| @@ -72,6 +102,7 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { | |||
| /// | |||
| void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std::vector<NodePtr> &switch_group) { | |||
| // Switch --> {Switch --> Merge} --> Merge | |||
| GELOGD("Search Switch node for Merge: %s", node->GetName().c_str()); | |||
| std::unordered_set<NodePtr> nodes_seen; | |||
| std::queue<std::pair<NodePtr, uint32_t>> search_queue({{node, 0}}); | |||
| while (!search_queue.empty()) { | |||
| @@ -79,43 +110,25 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std: | |||
| const auto dst_span = search_queue.front().second; | |||
| search_queue.pop(); | |||
| // Switch --> Identity --> Constant | |||
| for (const auto &in_node : dst_node->GetInControlNodes()) { | |||
| if (nodes_seen.count(in_node) > 0) { | |||
| GELOGD("Travel node: %s, Skip already seen node: %s", dst_node->GetName().c_str(), in_node->GetName().c_str()); | |||
| continue; | |||
| } | |||
| nodes_seen.insert(in_node); | |||
| if (in_node->GetType() == IDENTITY) { | |||
| GELOGD("Travel node: %s, In control: %s, span is: %u", dst_node->GetName().c_str(), | |||
| in_node->GetName().c_str(), dst_span); | |||
| search_queue.push({in_node, dst_span}); | |||
| } | |||
| } | |||
| for (const auto &in_node : dst_node->GetInDataNodes()) { | |||
| for (const auto &in_node : dst_node->GetInAllNodes()) { | |||
| if (nodes_seen.count(in_node) > 0) { | |||
| GELOGD("Travel node: %s, Skip already seen node: %s", dst_node->GetName().c_str(), in_node->GetName().c_str()); | |||
| continue; | |||
| } | |||
| nodes_seen.insert(in_node); | |||
| std::string node_type; | |||
| (void)GetOriginalType(in_node, node_type); | |||
| const std::string node_type = NodeUtils::GetNodeType(in_node); | |||
| GELOGD("Travel node: %s, %s node: %s, span is: %u", dst_node->GetName().c_str(), node_type.c_str(), | |||
| in_node->GetName().c_str(), dst_span); | |||
| if (kSwitchOpTypes.count(node_type) > 0) { // Switch input node. | |||
| if (DealAsLoopSwitch(in_node, dst_span, search_queue)) { | |||
| continue; | |||
| } | |||
| if (dst_span > 0) { | |||
| search_queue.push({in_node, dst_span - 1}); | |||
| } else { | |||
| const auto &all_in_nodes = in_node->GetInDataNodes(); | |||
| if (std::any_of(all_in_nodes.begin(), all_in_nodes.end(), IsSwitchInLoop)) { | |||
| GELOGW("Travel node: %s, %s node: %s, Skip LoopCond switch", dst_node->GetName().c_str(), node_type.c_str(), | |||
| in_node->GetName().c_str()); | |||
| } else { | |||
| switch_group.emplace_back(in_node); | |||
| } | |||
| switch_group.emplace_back(in_node); | |||
| } | |||
| } else if (kMergeOpTypes.count(node_type) > 0) { // Merge input node. | |||
| search_queue.push({in_node, dst_span + 1}); | |||
| @@ -19,12 +19,23 @@ | |||
| #include "inc/graph_pass.h" | |||
| #include <queue> | |||
| namespace ge { | |||
| class MarkForceUnknownForCondPass : public GraphPass { | |||
| public: | |||
| Status Run(ComputeGraphPtr graph); | |||
| private: | |||
| /// | |||
| /// @brief Deal with Switch node for LoopCond | |||
| /// @param [in] Switch node | |||
| /// @param [in] dest span | |||
| /// @param [out] Search queue | |||
| /// @return true: Switch In while loop / false: Not in while Loop. | |||
| /// | |||
| bool DealAsLoopSwitch(const NodePtr &node, uint32_t dst_span, std::queue<std::pair<NodePtr, uint32_t>> &search_queue); | |||
| /// | |||
| /// @brief Mark force unknown shape for Switch node | |||
| /// @param [in] merge node | |||
| @@ -24,7 +24,9 @@ using std::string; | |||
| namespace ge { | |||
| namespace { | |||
| const int64_t kLoopType = 1; | |||
| constexpr int64_t kLoopType = 1; | |||
| constexpr uint8_t kMaxTransOp = 3; | |||
| constexpr uint8_t kTransOpIoSize = 1; | |||
| } | |||
| Status NextIterationPass::Run(ComputeGraphPtr graph) { | |||
| @@ -287,18 +289,25 @@ void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group, i | |||
| std::string node_type; | |||
| for (const auto &switch_node : loop_group.switch_nodes) { | |||
| SetControlFlowGroup(switch_node, group_index); | |||
| for (const auto &node : switch_node->GetOutDataNodes()) { | |||
| (void)GetOriginalType(node, node_type); | |||
| if (kExitOpTypes.count(node_type) > 0) { | |||
| SetControlFlowGroup(node, group_index); | |||
| } else { | |||
| // For: Switch -> Cast -> Exit | |||
| for (const auto &n : node->GetOutDataNodes()) { | |||
| (void)GetOriginalType(n, node_type); | |||
| if (kExitOpTypes.count(node_type) > 0) { | |||
| SetControlFlowGroup(n, group_index); | |||
| } | |||
| for (auto node : switch_node->GetOutDataNodes()) { | |||
| // Switch --> Exit | |||
| // Switch --> Cast --> Exit | |||
| // Switch --> TransData --> Cast --> Exit | |||
| for (uint8_t i = 0; i < kMaxTransOp; ++i) { | |||
| if (node->GetInDataNodes().size() != kTransOpIoSize || node->GetAllOutDataAnchorsSize() != kTransOpIoSize) { | |||
| break; | |||
| } | |||
| if (kExitOpTypes.count(NodeUtils::GetNodeType(node)) > 0) { | |||
| SetControlFlowGroup(node, group_index); | |||
| break; | |||
| } | |||
| const auto &all_nodes = node->GetOutAllNodes(); | |||
| if (all_nodes.size() != kTransOpIoSize) { | |||
| break; | |||
| } | |||
| node = all_nodes.at(0); | |||
| } | |||
| } | |||
| } | |||
| @@ -395,8 +395,9 @@ NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr & | |||
| peer_cond_anchor->GetOwnerNode()->GetName().c_str(), stream_switch->GetName().c_str()); | |||
| int64_t group_index = -1; | |||
| (void)AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||
| SetControlFlowGroup(stream_switch, group_index); | |||
| if (AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) { | |||
| SetControlFlowGroup(stream_switch, group_index); | |||
| } | |||
| return stream_switch; | |||
| } | |||
| @@ -326,17 +326,45 @@ std::shared_ptr<TaskContext> NodeState::GetTaskContext() { | |||
| } | |||
| void NodeState::SavePersistTensor(int input_idx, const TensorValue &tensor) { | |||
| if (node_item_->root_data_.count(input_idx) > 0) { | |||
| GELOGD("[%s] Save Root input tensor: %d", GetName().c_str(), input_idx); | |||
| root_tensor_values_[input_idx] = tensor; | |||
| const auto is_persist_tensor = [](const std::map<const NodeItem *, std::set<int>> &items, int idx) { | |||
| const auto is_exist = [&idx](const std::pair<const NodeItem *, std::set<int>> &items) { | |||
| return items.second.count(idx) > 0; | |||
| }; | |||
| return std::any_of(items.begin(), items.end(), is_exist); | |||
| }; | |||
| if (root_tensor_values_.count(input_idx) > 0) { | |||
| return; | |||
| } | |||
| if (node_item_->enter_data_.count(input_idx) > 0) { | |||
| if (is_persist_tensor(node_item_->root_data_, input_idx)) { | |||
| GELOGD("[%s] Save Root input tensor: %d", GetName().c_str(), input_idx); | |||
| root_tensor_values_[input_idx] = tensor; | |||
| } else if (is_persist_tensor(node_item_->enter_data_, input_idx)) { | |||
| GELOGD("[%s] Save Enter input tensor: %d", GetName().c_str(), input_idx); | |||
| root_tensor_values_[input_idx] = tensor; | |||
| } | |||
| } | |||
| void NodeState::UpdatePersistTensor() { | |||
| const auto update_tensor = [&](const std::map<const NodeItem *, std::set<int>> &items) { | |||
| for (const auto &item : items) { | |||
| for (const auto idx : item.second) { | |||
| UpdatePersistTensor(idx); | |||
| } | |||
| } | |||
| }; | |||
| if (root_tensor_values_.empty()) { | |||
| return; | |||
| } | |||
| update_tensor(node_item_->root_data_); | |||
| if (iteration_count_ > 0) { | |||
| update_tensor(node_item_->enter_data_); | |||
| } | |||
| } | |||
| void NodeState::UpdatePersistTensor(int input_idx) { | |||
| const auto it = root_tensor_values_.find(input_idx); | |||
| if (it == root_tensor_values_.end()) { | |||
| @@ -363,16 +391,9 @@ void NodeState::ResetContext(uint64_t iteration) { | |||
| data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); | |||
| ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); | |||
| for (auto item : node_item_->root_data_) { | |||
| UpdatePersistTensor(item.first); | |||
| } | |||
| if (iteration > 0) { | |||
| data_scheduled_ += static_cast<uint32_t>(node_item_->enter_data_.size()); | |||
| ctrl_scheduled_ += static_cast<uint32_t>(node_item_->enter_ctrl_.size()); | |||
| for (auto item : node_item_->enter_data_) { | |||
| UpdatePersistTensor(item.first); | |||
| } | |||
| } | |||
| iteration_count_ = iteration; | |||
| @@ -132,6 +132,7 @@ struct NodeState { | |||
| void RunNextIteration(); | |||
| void SavePersistTensor(int input_idx, const TensorValue &tensor); | |||
| void UpdatePersistTensor(); | |||
| Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const; | |||
| @@ -373,6 +373,7 @@ Status ExecutionEngine::DoExecuteAsync(NodeState &node_state, | |||
| auto executor = node_item.node_executor; | |||
| GE_CHECK_NOTNULL(executor); | |||
| RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] Start"); | |||
| node_state.UpdatePersistTensor(); | |||
| GE_CHK_STATUS_RET(executor->PrepareTask(*task, task_context), "[Prepare][Task] for [%s] failed.", | |||
| node_state.GetName().c_str()); | |||
| RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] End"); | |||
| @@ -395,11 +395,13 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) { | |||
| data_send_.emplace(node_item); | |||
| node_item->data_recv_[this] = anchor_index; | |||
| if (is_root_node_) { | |||
| node_item->root_data_[anchor_index] = this; | |||
| auto &data_anchors = node_item->root_data_[this]; | |||
| data_anchors.emplace(anchor_index); | |||
| } | |||
| // If Enter feed Not Merge, take as root Node. | |||
| if (IsEnterOp() && (node_item->node_type != STREAMMERGE)) { | |||
| node_item->enter_data_[anchor_index] = this; | |||
| auto &data_anchors = node_item->enter_data_[this]; | |||
| data_anchors.emplace(anchor_index); | |||
| } | |||
| GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | |||
| } | |||
| @@ -148,9 +148,9 @@ struct NodeItem { | |||
| int64_t frame_index_ = -1; | |||
| int64_t parent_frame_ = -1; | |||
| std::set<const NodeItem *> root_ctrl_; // Recv ctrl from root node | |||
| std::map<int, const NodeItem *> root_data_; // Recv data from root node | |||
| std::map<const NodeItem *, std::set<int>> root_data_; // Recv data from root node | |||
| std::set<const NodeItem *> enter_ctrl_; // Recv ctrl from Enter node | |||
| std::map<int, const NodeItem *> enter_data_; // Recv data from Enter node | |||
| std::map<const NodeItem *, std::set<int>> enter_data_; // Recv data from Enter node | |||
| std::set<const NodeItem *> data_send_; // Send data notify to | |||
| std::map<const NodeItem *, int> data_recv_; // Recv data notify from | |||
| std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to | |||
| @@ -460,10 +460,6 @@ Status TaskContext::PropagateOutputs() { | |||
| subgraph_context_->all_inputs_[input_offset].SetName( | |||
| node_item_->NodeName() + "_in_" + std::to_string(dst_input_idx)); | |||
| } | |||
| auto dst_node_state = subgraph_context_->GetOrCreateNodeState(dst_node_item); | |||
| GE_CHECK_NOTNULL(dst_node_state); | |||
| dst_node_state->SavePersistTensor(dst_input_idx, *tensor); | |||
| } | |||
| } | |||
| (void)guard; | |||
| @@ -495,6 +491,7 @@ void TaskContext::ReleaseInputsAndOutputs() { | |||
| void TaskContext::ReleaseInput(int index) { | |||
| auto input_tensor = MutableInput(index); | |||
| if (input_tensor != nullptr) { | |||
| node_state_->SavePersistTensor(index, *input_tensor); | |||
| input_tensor->Destroy(); | |||
| GELOGD("[%s] Tensor of input[%d] released", GetNodeName(), index); | |||
| } | |||
| @@ -345,6 +345,10 @@ INT32 mmIsDir(const CHAR *fileName) | |||
| INT32 mmGetEnv(const CHAR *name, CHAR *value, UINT32 len) | |||
| { | |||
| const char *env = getenv(name); | |||
| if (env != nullptr) { | |||
| strcpy(value, env); | |||
| } | |||
| return 0; | |||
| } | |||
| @@ -866,7 +866,6 @@ set(HYBRID_TEST_FILES | |||
| "hybrid/executor/hybrid_model_async_executor_unittest.cc" | |||
| "hybrid/executor/hybrid_model_pipeline_executor_unittest.cc" | |||
| "hybrid/node_executor/aicore/aicore_task_compiler_unittest.cc" | |||
| ) | |||
| set(OTHERS_TEST_FILES | |||
| @@ -894,6 +893,7 @@ add_library(ge_ut_graph STATIC | |||
| target_compile_definitions(ge_ut_graph PRIVATE | |||
| google=ascend_private | |||
| FMK_SUPPORT_DUMP | |||
| ) | |||
| target_compile_options(ge_ut_graph PRIVATE | |||
| @@ -349,7 +349,7 @@ class UtestLogicalStreamAllocator : public testing::Test { | |||
| /// B --> C(AllReduce) --- D | |||
| /// / | |||
| /// stream id: 0 A | |||
| /// \ | |||
| /// \. | |||
| /// E --> F(AllReduce) --- G | |||
| /// stream id: 2 2 2 | |||
| /// | |||
| @@ -599,7 +599,7 @@ TEST_F(UtestLogicalStreamAllocator, test_label_not_reusable2) { | |||
| /// case of multi-output, then unuse stream | |||
| /// sub1 | |||
| /// / | \ | |||
| /// / | \. | |||
| /// sub2 sub3 sub4 | |||
| TEST_F(UtestLogicalStreamAllocator, test_multiOut_new_stream) { | |||
| SubGraphInfoPtr data = CreateDataSubgraph(); | |||
| @@ -624,7 +624,7 @@ TEST_F(UtestLogicalStreamAllocator, test_multiOut_new_stream) { | |||
| /// if paralle id 1, then use stream | |||
| /// sub1 | |||
| /// / | | \ | |||
| /// / | | \. | |||
| /// sub2 sub3 sub4 sub5 | |||
| TEST_F(UtestLogicalStreamAllocator, test_parallel_one) { | |||
| SubGraphInfoPtr data = CreateDataSubgraph(); | |||
| @@ -653,7 +653,7 @@ TEST_F(UtestLogicalStreamAllocator, test_parallel_one) { | |||
| /// if the param of engine independent is true, then set independent stream | |||
| /// sub1 | |||
| /// / | | \ | |||
| /// / | | \. | |||
| /// sub2 sub3 sub4 sub5 | |||
| TEST_F(UtestLogicalStreamAllocator, test_independent) { | |||
| SubGraphInfoPtr data = CreateDataSubgraph(); | |||
| @@ -692,7 +692,7 @@ TEST_F(UtestLogicalStreamAllocator, test_independent) { | |||
| /// set stream based on stream label, and then based on independent | |||
| /// sub1 | |||
| /// / | | \ | |||
| /// / | | \. | |||
| /// sub2 sub3 sub4 sub5 | |||
| TEST_F(UtestLogicalStreamAllocator, test_independent_switch_label) { | |||
| SubGraphInfoPtr data = CreateDataSubgraph(); | |||
| @@ -36,7 +36,7 @@ class UtestStreamAllocator : public testing::Test { | |||
| /// | |||
| /// A | |||
| /// / \ | |||
| /// / \. | |||
| /// B C | |||
| /// | | | |||
| /// D 400 | |||
| @@ -55,7 +55,7 @@ class UtestGraphPassesAssertPass : public Test { | |||
| }; | |||
| /// D E | |||
| /// | \ | \ | |||
| /// | \ | \. | |||
| /// F C G | |||
| /// : | : | |||
| /// H A I | |||
| @@ -134,8 +134,8 @@ TEST_F(UtestGraphPassesAssertPass, assert_pass_test2) { | |||
| EXPECT_EQ(graph->FindNode("D"), nullptr); | |||
| } | |||
| /// E F | |||
| /// | \ | \ | |||
| /// E F | |||
| /// | \ | \. | |||
| /// H C -> D G | |||
| /// \ | : | |||
| /// A I | |||
| @@ -130,7 +130,7 @@ class UTESTGraphPassesBasePass : public testing::Test { | |||
| /// reshape1 | |||
| /// | | |||
| /// add1 | |||
| /// / \ | |||
| /// / \. | |||
| /// | | | |||
| /// data1 const1 | |||
| ComputeGraphPtr BuildGraph1() { | |||
| @@ -148,9 +148,9 @@ ComputeGraphPtr BuildGraph1() { | |||
| } | |||
| /// sum1 | |||
| /// / \ | |||
| /// / \ | |||
| /// / \ | |||
| /// / \. | |||
| /// / \. | |||
| /// / \. | |||
| /// reshape1 addn1 | |||
| /// | c | | |||
| /// add1 <--- shape1 | |||
| @@ -217,7 +217,7 @@ void CheckIterOrder(UtestTestPass *pass, std::vector<std::unordered_set<std::str | |||
| /// Op1 | |||
| /// | | |||
| /// Merge | |||
| /// / \ | |||
| /// / \. | |||
| /// Op2 Op3 | |||
| TEST_F(UTESTGraphPassesBasePass, del_isolate_fail) { | |||
| auto builder = ut::GraphBuilder("g1"); | |||
| @@ -245,7 +245,7 @@ TEST_F(UTESTGraphPassesBasePass, del_isolate_fail) { | |||
| /// Op1 | |||
| /// | | |||
| /// Merge | |||
| /// / \ | |||
| /// / \. | |||
| /// Op2 Op3 | |||
| TEST_F(UTESTGraphPassesBasePass, del_isolate_success) { | |||
| auto builder = ut::GraphBuilder("g1"); | |||
| @@ -459,7 +459,7 @@ TEST_F(UTESTGraphPassesBasePass, while_loop) { | |||
| /// data1 const | |||
| /// \ / | |||
| /// while | |||
| /// / \ | |||
| /// / \. | |||
| /// | | | |||
| /// cast1 cast2 | |||
| ComputeGraphPtr BuildWhileGraph1() { | |||
| @@ -34,11 +34,11 @@ namespace { | |||
| /// net_output | |||
| /// | | |||
| /// merge | |||
| /// / \ | |||
| /// / \. | |||
| /// square add | |||
| /// F| T/ T\ | |||
| /// F| T/ T\. | |||
| /// switch1 switch2 | |||
| /// / \ / \ | |||
| /// / \ / \. | |||
| /// var1 var2 var3 | |||
| /// | |||
| ComputeGraphPtr BuildGraph1() { | |||
| @@ -173,8 +173,8 @@ namespace { | |||
| /// shapeNo1 | |||
| /// | | |||
| /// addnYes1 | |||
| /// / \ | |||
| /// / \ | |||
| /// / \. | |||
| /// / \. | |||
| /// const1 const2 | |||
| ComputeGraphPtr BuildGraph1() { | |||
| auto builder = ut::GraphBuilder("test"); | |||
| @@ -223,8 +223,8 @@ ComputeGraphPtr BuildGraph2() { | |||
| /// shapeNo1 | |||
| /// | c | |||
| /// addnYes1 <----- dataNo1 | |||
| /// / \ | |||
| /// / \ | |||
| /// / \. | |||
| /// / \. | |||
| /// const1 const2 | |||
| ComputeGraphPtr BuildGraph3() { | |||
| auto builder = ut::GraphBuilder("test"); | |||
| @@ -249,8 +249,8 @@ ComputeGraphPtr BuildGraph3() { | |||
| /// shapeNo1 | |||
| /// | c | |||
| /// addnYes1 <--------- | |||
| /// / \ \ | |||
| /// / \ c \ | |||
| /// / \ \. | |||
| /// / \ c \. | |||
| /// const1 const2 <----- dataNo1 | |||
| ComputeGraphPtr BuildGraph4() { | |||
| auto builder = ut::GraphBuilder("test"); | |||
| @@ -276,7 +276,7 @@ ComputeGraphPtr BuildGraph4() { | |||
| /// shapeNo1 | |||
| /// | c | |||
| /// addnYes1 <----- dataNo1 | |||
| /// / \ | |||
| /// / \. | |||
| /// / \ c | |||
| /// const1 const2 <----- dataNo2 | |||
| ComputeGraphPtr BuildGraph5() { | |||
| @@ -306,8 +306,8 @@ ComputeGraphPtr BuildGraph5() { | |||
| /// addYes1 <---- const3 | |||
| /// | | |||
| /// addnYes1 <- | |||
| /// / \ \ | |||
| /// / \ \ | |||
| /// / \ \. | |||
| /// / \ \. | |||
| /// const1 const2 const4 | |||
| ComputeGraphPtr BuildGraph6() { | |||
| auto builder = ut::GraphBuilder("test"); | |||
| @@ -332,12 +332,12 @@ ComputeGraphPtr BuildGraph6() { | |||
| } | |||
| /// netoutput1 | |||
| /// / \ | |||
| /// / \. | |||
| /// shapeNo1 ShpaeNo2 | |||
| /// \ / | |||
| /// huberLoss1 | |||
| /// / | \ | |||
| /// / | \ | |||
| /// / | \. | |||
| /// / | \. | |||
| /// const1 const2 const3 | |||
| ComputeGraphPtr BuildGraph7() { | |||
| auto builder = ut::GraphBuilder("test"); | |||
| @@ -365,8 +365,8 @@ ComputeGraphPtr BuildGraph7() { | |||
| /// shapeNo1 | |||
| /// | | |||
| /// addnNo1 | |||
| /// / \ | |||
| /// / \ | |||
| /// / \. | |||
| /// / \. | |||
| /// const1 const2 | |||
| ComputeGraphPtr BuildGraph8() { | |||
| auto builder = ut::GraphBuilder("test"); | |||
| @@ -389,8 +389,8 @@ ComputeGraphPtr BuildGraph8() { | |||
| /// shapeNo1 | |||
| /// | | |||
| /// addnYes1 | |||
| /// / \ | |||
| /// / \ | |||
| /// / \. | |||
| /// / \. | |||
| /// const1 data1 | |||
| ComputeGraphPtr BuildGraph9() { | |||
| auto builder = ut::GraphBuilder("test"); | |||
| @@ -409,12 +409,12 @@ ComputeGraphPtr BuildGraph9() { | |||
| } | |||
| /// netoutput1 | |||
| /// / \ | |||
| /// / \. | |||
| /// addDim sqrt1 | |||
| /// \ / | |||
| /// switch1 | |||
| /// / \ | |||
| /// / \ | |||
| /// / \. | |||
| /// / \. | |||
| /// const1 const2 | |||
| ComputeGraphPtr BuildGraph10() { | |||
| auto builder = ut::GraphBuilder("test"); | |||
| @@ -63,8 +63,8 @@ namespace { | |||
| /// shapeNo1 | |||
| /// | | |||
| /// addnNo1 | |||
| /// / \ | |||
| /// / \ | |||
| /// / \. | |||
| /// / \. | |||
| /// const1 const2 | |||
| ComputeGraphPtr BuildGraph8() { | |||
| auto builder = ut::GraphBuilder("test"); | |||
| @@ -87,8 +87,8 @@ ComputeGraphPtr BuildGraph8() { | |||
| /// shapeNo1 | |||
| /// | | |||
| /// addnYes1 | |||
| /// / \ | |||
| /// / \ | |||
| /// / \. | |||
| /// / \. | |||
| ///const1 data1 | |||
| ComputeGraphPtr BuildGraph9() { | |||
| auto builder = ut::GraphBuilder("test"); | |||
| @@ -46,7 +46,7 @@ class UtestGraphPassesFoldingKernelSsdPriorboxKernel : public testing::Test { | |||
| /// convolution data | |||
| /// | / | |||
| /// ssdpriorbox | |||
| /// \ | |||
| /// \. | |||
| /// reshape | |||
| class NodeBuilder { | |||
| public: | |||
| @@ -120,7 +120,7 @@ TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph1) { | |||
| /// graph with subgraph | |||
| /// const | |||
| /// / \ | |||
| /// / \. | |||
| /// cast1 cast1 | |||
| /// \ / | |||
| /// case | |||
| @@ -69,62 +69,100 @@ static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string | |||
| return graph.AddNode(op_desc); | |||
| } | |||
| static void CreateLoopGraph(ComputeGraphPtr &graph, NodePtr &merge) { | |||
| static void CreateLoopGraph(ComputeGraphPtr &graph, NodePtr &merge, vector<NodePtr> &loop, vector<NodePtr> &cond) { | |||
| /******************************************************************************* | |||
| * Exit Identify | |||
| * \ / \. | |||
| * \ / \. | |||
| * Switch Add | |||
| * / | | | |||
| * / | | | |||
| * / | | | |||
| * LoopCond | | | |||
| * \ | | | |||
| * \ | | | |||
| * \ | | | |||
| * Less | | | |||
| * \ | NextIteration | |||
| * \ | | | |||
| * \ | | | |||
| * Merge <---------| | |||
| * | | |||
| * | | |||
| * Enter | |||
| * | | |||
| * +--------------------- Merge ----------------------+ | |||
| * / | | |||
| * / | | |||
| * / | | |||
| * / | | |||
| * Exit Identify | | |||
| * \ / \. | | |||
| * \ / \. | | |||
| * Switch Add Add | |||
| * / | | | | |||
| * / | | | | |||
| * / | | | | |||
| * LoopCond | | | | |||
| * \ | | | | |||
| * \ | | | | |||
| * \ | | | | |||
| * Less | | | | |||
| * \ | NextIteration | | |||
| * \ | | | | |||
| * \ | | | | |||
| * Merge <---------| | | |||
| * | | | |||
| * | | | |||
| * Enter | | |||
| * \ | | |||
| * \ | | |||
| * Switch Switch | |||
| * | | | |||
| * +-----------------Equal----------------------+ | |||
| * | | |||
| ******************************************************************************/ | |||
| auto data1 = CreateNode(*graph, "data", DATA, 1, 1); | |||
| auto data1 = CreateNode(*graph, "data1", DATA, 1, 1); | |||
| auto data2 = CreateNode(*graph, "data2", DATA, 1, 1); | |||
| auto equal1 = CreateNode(*graph, "equal1", EQUAL, 2, 1); | |||
| auto switch1 = CreateNode(*graph, "switch1", SWITCH, 2, 2); | |||
| auto switch2 = CreateNode(*graph, "switch2", SWITCH, 2, 2); | |||
| auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1); | |||
| auto merge1 = CreateNode(*graph, "merge", MERGE, 2, 2); | |||
| auto less1 = CreateNode(*graph, "less", LESS, 2, 1); | |||
| auto merge1 = CreateNode(*graph, "merge1", MERGE, 2, 2); | |||
| auto less1 = CreateNode(*graph, "less1", LESS, 2, 1); | |||
| auto loop1 = CreateNode(*graph, "loopcond", LOOPCOND, 1, 1); | |||
| auto switch1 = CreateNode(*graph, "switch", SWITCH, 2, 2); | |||
| auto switch3 = CreateNode(*graph, "switch3", SWITCH, 2, 2); | |||
| auto ident1 = CreateNode(*graph, "identity", IDENTITY, 1, 1); | |||
| auto add1 = CreateNode(*graph, "add", ADD, 2, 1); | |||
| auto add1 = CreateNode(*graph, "add1", ADD, 2, 1); | |||
| auto next1 = CreateNode(*graph, "next", NEXTITERATION, 1, 1); | |||
| auto exit1 = CreateNode(*graph, "exit", EXIT, 1, 1); | |||
| auto value0 = CreateNode(*graph, "const", CONSTANT, 0, 1); | |||
| auto value1 = CreateNode(*graph, "const", CONSTANT, 0, 1); | |||
| auto value1 = CreateNode(*graph, "const1", CONSTANT, 0, 1); | |||
| auto value2 = CreateNode(*graph, "const2", CONSTANT, 0, 1); | |||
| auto add2 = CreateNode(*graph, "add2", ADD, 2, 1); | |||
| auto merge2 = CreateNode(*graph, "merge2", MERGE, 2, 2); | |||
| auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1); | |||
| GraphUtils::AddEdge(data1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(data1->GetOutDataAnchor(0), equal1->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(data2->GetOutDataAnchor(0), equal1->GetInDataAnchor(1)); | |||
| GraphUtils::AddEdge(data1->GetOutDataAnchor(0), switch1->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(data2->GetOutDataAnchor(0), switch2->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(equal1->GetOutDataAnchor(0), switch1->GetInDataAnchor(1)); | |||
| GraphUtils::AddEdge(equal1->GetOutDataAnchor(0), switch2->GetInDataAnchor(1)); | |||
| cond.emplace_back(switch1); | |||
| cond.emplace_back(switch2); | |||
| GraphUtils::AddEdge(switch1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0)); // false | |||
| GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), less1->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); | |||
| GraphUtils::AddEdge(less1->GetOutDataAnchor(0), loop1->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch1->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), switch1->GetInDataAnchor(1)); | |||
| GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch3->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), switch3->GetInDataAnchor(1)); | |||
| loop.emplace_back(merge1); | |||
| GraphUtils::AddEdge(switch1->GetOutDataAnchor(0), exit1->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(switch1->GetOutDataAnchor(1), ident1->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(switch3->GetOutDataAnchor(0), exit1->GetInDataAnchor(0)); // false | |||
| GraphUtils::AddEdge(switch3->GetOutDataAnchor(1), ident1->GetInDataAnchor(0)); // true | |||
| loop.emplace_back(switch3); | |||
| GraphUtils::AddEdge(ident1->GetOutDataAnchor(0), add1->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(value1->GetOutDataAnchor(0), add1->GetInDataAnchor(1)); | |||
| GraphUtils::AddEdge(add1->GetOutDataAnchor(0), next1->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(next1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1)); | |||
| GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); | |||
| merge = merge1; | |||
| GraphUtils::AddEdge(switch2->GetOutDataAnchor(1), add2->GetInDataAnchor(1)); // true | |||
| GraphUtils::AddEdge(value2->GetOutDataAnchor(0), add2->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), merge2->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(add2->GetOutDataAnchor(0), merge2->GetInDataAnchor(1)); | |||
| GraphUtils::AddEdge(merge2->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); | |||
| cond.emplace_back(merge2); | |||
| merge = merge2; | |||
| } | |||
| static void CreateCondGraph(ComputeGraphPtr &graph, NodePtr &merge) { | |||
| @@ -197,12 +235,24 @@ static void CreateCondGraph(ComputeGraphPtr &graph, NodePtr &merge) { | |||
| TEST_F(UtestMarkForceUnknownForCondPass, skip_while_loop_merge) { | |||
| auto graph = std::make_shared<ComputeGraph>("test_graph"); | |||
| NodePtr merge; | |||
| CreateLoopGraph(graph, merge); | |||
| AttrUtils::SetBool(merge->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true); | |||
| vector<NodePtr> loop; | |||
| vector<NodePtr> cond; | |||
| CreateLoopGraph(graph, merge, loop, cond); | |||
| MarkForceUnknownForCondPass mark_force_unknown_pass; | |||
| EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS); // skip LoopCond | |||
| EXPECT_EQ(loop.size(), 2); | |||
| for (const auto &node : loop) { | |||
| EXPECT_FALSE(node->GetOpDesc()->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)); | |||
| } | |||
| EXPECT_EQ(cond.size(), 3); | |||
| for (const auto &node : cond) { | |||
| int64_t group_index = -1; | |||
| EXPECT_TRUE(AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)); | |||
| EXPECT_EQ(group_index, merge->GetOpDesc()->GetId()); | |||
| } | |||
| } | |||
| TEST_F(UtestMarkForceUnknownForCondPass, skip_known_shape_merge) { | |||
| @@ -110,8 +110,8 @@ TEST_F(UtestGraphPassesMergePass, multiple_inputs) { | |||
| } | |||
| /// Merge | |||
| /// | \ | |||
| /// | \ | |||
| /// | \. | |||
| /// | \. | |||
| /// Op1 Op2 Merge2 | |||
| /// \ | | | |||
| /// \ | Op3 | |||
| @@ -137,10 +137,10 @@ TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch_meet_net_output_with_da | |||
| } | |||
| /// Merge | |||
| /// | \ | |||
| /// | \ | |||
| /// | \. | |||
| /// | \. | |||
| /// Op1 Op2 Merge2 | |||
| /// \ | | \ | |||
| /// \ | | \. | |||
| /// \ | Op3 | |||
| /// \ | : | |||
| /// NetOutput | |||
| @@ -165,8 +165,8 @@ TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch_meet_net_output_with_co | |||
| TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch) { | |||
| /// Merge | |||
| /// | \ | |||
| /// | \ | |||
| /// | \. | |||
| /// | \. | |||
| /// Op1 Op2 Merge2 | |||
| /// \ | | | |||
| /// \ | Op3 | |||
| @@ -210,7 +210,7 @@ TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch) { | |||
| /// Op1 Op2 Merge2 | |||
| /// \ | | |||
| /// \ Op3 | |||
| /// \ | |||
| /// \. | |||
| /// Merge3 | |||
| ret = pass_.Run(merge_node2); | |||
| @@ -224,7 +224,7 @@ TEST_F(UtestGraphPassesMergePass, single_non_const_input) { | |||
| /// Op1 | |||
| /// | | |||
| /// Merge | |||
| /// / \ | |||
| /// / \. | |||
| /// Op2 Op3 | |||
| auto merge_node = NewNode("Merge", MERGE, 1, 2); | |||
| auto node1 = NewNode("Op1", RELU, 1, 1); | |||
| @@ -253,7 +253,7 @@ TEST_F(UtestGraphPassesMergePass, single_const_input) { | |||
| /// Const | |||
| /// | | |||
| /// Merge Pass Const | |||
| /// / \ ===> / \ | |||
| /// / \ ===> / \. | |||
| /// Op1 Op2 Op1 Op2 | |||
| auto merge_node = NewNode("Merge", MERGE, 1, 2); | |||
| auto const_node = NewNode("Const", CONSTANT, 1, 1); | |||
| @@ -284,7 +284,7 @@ TEST_F(UtestGraphPassesMergePass, single_const_input_value_index_two_out_nodes) | |||
| /// / | ===> / \(control anchor) | |||
| /// Op1 | \ Op1 Constant | |||
| /// Op2 Op3 | | |||
| /// / \ | |||
| /// / \. | |||
| /// Op2 Op3 | |||
| auto merge_node = NewNode("Merge", MERGE, 1, 2); | |||
| auto const_node = NewNode("Const", CONSTANT, 1, 1); | |||
| @@ -329,7 +329,7 @@ TEST_F(UtestGraphPassesMergePass, single_const_input_value_index_two_out_nodes1) | |||
| /// / | ===> / \(control anchor) | |||
| /// Op1 | \ Op1 Constant | |||
| /// Op2 Op3 | | |||
| /// / \ | |||
| /// / \. | |||
| /// Op2 Op3 | |||
| auto merge_node = NewNode("Merge", MERGE, 1, 2); | |||
| auto const_node = NewNode("Const", CONSTANT, 1, 1); | |||
| @@ -357,7 +357,7 @@ TEST_F(UtestGraphPassesMergePass, const_with_control_input) { | |||
| /// C | |||
| /// | | |||
| /// Merge | |||
| /// / \ | |||
| /// / \. | |||
| /// Op1 Op2 | |||
| auto switch_node = NewNode("Switch", SWITCH, 1, 2); | |||
| auto identity_node = NewNode("Identity", SWITCH, 1, 1); | |||
| @@ -381,7 +381,7 @@ TEST_F(UtestGraphPassesMergePass, const_with_control_input) { | |||
| /// . | |||
| /// . | |||
| /// C | |||
| /// / \ | |||
| /// / \. | |||
| /// Op1 Op2 | |||
| auto ret = pass_.Run(merge_node); | |||
| EXPECT_EQ(ret, SUCCESS); | |||
| @@ -66,11 +66,11 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test { | |||
| void BuildDefaultGraph() { | |||
| /// input | |||
| /// \ | |||
| /// \. | |||
| /// sqrt pred | |||
| /// \ / | |||
| /// cast | |||
| /// / \ | |||
| /// / \. | |||
| /// switch_t switch_f | |||
| /// | | | |||
| /// F T | |||
| @@ -118,13 +118,13 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test { | |||
| void BuildDefaultGraph1() { | |||
| /// input | |||
| /// \ | |||
| /// \. | |||
| /// sqrt pred | |||
| /// \ / | |||
| /// Switch | |||
| /// | | | |||
| /// ----F T---- | |||
| /// \ | / \ | |||
| /// \ | / \. | |||
| /// \ Merge1 Merge2 | |||
| /// \_________| | |||
| input_node_ = NewNode("input", RELU, 0, 1); | |||
| @@ -164,14 +164,14 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test { | |||
| void BuildDefaultGraph2() { | |||
| /// input input1 | |||
| /// \ \ | |||
| /// \ \. | |||
| /// sqrt pred sqrt1 pred1 | |||
| /// \ / \ / | |||
| /// Switch Switch1 | |||
| /// | | _______| | |||
| /// | | / | |||
| /// ____F T____ | |||
| /// \ | / \ | |||
| /// \ | / \. | |||
| /// \ Merge1 Merge2 | |||
| /// \__________| | |||
| input_node_ = NewNode("input", RELU, 0, 2); | |||
| @@ -31,9 +31,9 @@ class UtestReshapeRecoveryPass : public testing::Test { | |||
| namespace { | |||
| /// netoutput1 | |||
| /// | \ | |||
| ///transdata1 \ | |||
| /// | \ | |||
| /// | \. | |||
| ///transdata1 \. | |||
| /// | \. | |||
| /// | transdata2 | |||
| /// | / | |||
| /// var1 const1 | |||
| @@ -35,7 +35,7 @@ namespace { | |||
| /// transdata1 | |||
| /// | | |||
| /// reshape1 | |||
| /// | \ | |||
| /// | \. | |||
| /// var1 const1 | |||
| ut::GraphBuilder Graph1Builder() { | |||
| ut::GraphBuilder builder = ut::GraphBuilder("g1"); | |||
| @@ -55,11 +55,11 @@ ut::GraphBuilder Graph1Builder() { | |||
| } | |||
| /// netoutput1 | |||
| /// | \ | |||
| ///transdata1 \ | |||
| /// | \ | |||
| /// | \. | |||
| ///transdata1 \. | |||
| /// | \. | |||
| /// reshape1 reshape2 | |||
| /// | \ / \ | |||
| /// | \ / \. | |||
| /// var1 const1 var2 | |||
| ut::GraphBuilder Graph2Builder() { | |||
| ut::GraphBuilder builder = ut::GraphBuilder("g2"); | |||
| @@ -83,9 +83,9 @@ ut::GraphBuilder Graph2Builder() { | |||
| } | |||
| /// netoutput1 | |||
| /// | \ | |||
| ///transdata1 \ | |||
| /// | \ | |||
| /// | \. | |||
| ///transdata1 \. | |||
| /// | \. | |||
| /// reshape1 transdata2 | |||
| /// | \ / | |||
| /// var1 const1 | |||
| @@ -34,7 +34,7 @@ class UtestResourcePairControlPass : public testing::Test { | |||
| namespace { | |||
| /// netoutput1 | |||
| /// | \ | |||
| /// | \. | |||
| /// StackPush StackPop | |||
| /// | | | |||
| /// var1 const1 | |||
| @@ -63,9 +63,9 @@ ComputeGraphPtr BuildGraph1() { | |||
| /// netoutput1 | |||
| /// | | |||
| /// merge1 | |||
| /// / \ | |||
| /// / \. | |||
| /// / add1 | |||
| /// / F| \ | |||
| /// / F| \. | |||
| /// addn1 swtich2 var3 | |||
| /// \F T/ | | |||
| /// switch1 | | |||
| @@ -101,9 +101,9 @@ ComputeGraphPtr BuildGraph2() { | |||
| /// add1 | |||
| /// / \T | |||
| /// var3 swtich2 | |||
| /// T/ \ | |||
| /// switch1 \ | |||
| /// / \ \ | |||
| /// T/ \. | |||
| /// switch1 \. | |||
| /// / \ \. | |||
| /// var1 var2 var4 | |||
| ComputeGraphPtr BuildGraph3() { | |||
| auto builder = ut::GraphBuilder("g3"); | |||
| @@ -129,7 +129,7 @@ ComputeGraphPtr BuildGraph3() { | |||
| /// netoutput1 | |||
| /// | | |||
| /// merge1 | |||
| /// / \ | |||
| /// / \. | |||
| /// add1 addn1 | |||
| /// / \T F/ | |||
| /// var3 swtich2 | |||
| @@ -402,7 +402,7 @@ TEST_F(UtestGraphPassesTransOpBreadthFusionPass, test_multi_anchor_case) { | |||
| } | |||
| /// ----> netoutput1 | |||
| /// / | \ | |||
| /// / | \. | |||
| /// transdata1 transdata2 transdata3 | |||
| /// \ / | | |||
| /// var1-------------- | |||
| @@ -432,7 +432,7 @@ static ComputeGraphPtr BuildGraph1() { | |||
| } | |||
| /// ---------> netoutput1 | |||
| /// / | \ | |||
| /// / | \. | |||
| /// transdata1 transdata2(l1) transdata3(l1) | |||
| /// \ / | | |||
| /// var1------------------ | |||
| @@ -456,19 +456,19 @@ TEST_F(UtestGraphPassesTransOpDepthFusionPass, test_transop_with_multi_out_edge) | |||
| /// -->transpose1 -->transpose3-->sinh2 | |||
| /// | \ / | |||
| /// | -->transpose2 | |||
| /// | \ | |||
| /// | \. | |||
| /// / -->cast3-->cast4-->sinh3 | |||
| /// / | |||
| /// / -->transpose4-->transpose5-->sinh4 | |||
| /// / / | |||
| /// Node4D-->Cast1-->Cast2-->Cast5 -->reshape2-->sinh5 | |||
| /// \ \ | |||
| /// \ \. | |||
| /// \ -->sinh6 | |||
| /// \ | |||
| /// \. | |||
| /// \ -->transpose6-->transpose7-->sinh9 | |||
| /// \ / | |||
| /// -->reshape-->cast6-->cast7-->sinh8 | |||
| /// \ | |||
| /// \. | |||
| /// -->sinh7 | |||
| /// after optimized graph | |||
| @@ -479,15 +479,15 @@ TEST_F(UtestGraphPassesTransOpDepthFusionPass, test_transop_with_multi_out_edge) | |||
| /// / /-->transpose3-->sinh2 | |||
| /// -->Cast1 | |||
| /// / \-->sinh7 | |||
| /// / \ | |||
| /// / \. | |||
| /// / -->sinh9 | |||
| /// Node4D | |||
| /// \ -->sinh4 | |||
| /// \ / | |||
| /// -->Cast5-->sinh5 | |||
| /// \ \ | |||
| /// \ \. | |||
| /// \ -->sinh6 | |||
| /// \ | |||
| /// \. | |||
| /// -->Cast7-->sinh8 | |||
| ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||
| @@ -180,7 +180,7 @@ ComputeGraphPtr GetGraph7(size_t symmetric_transdata_num, size_t asymmetric_tran | |||
| /// TransData TransData ... MatMul ... | |||
| /// \ | / / / | |||
| /// HcomAllReduce | |||
| /// / | \ \ \ | |||
| /// / | \ \ \. | |||
| /// TransData TransData ... RealDiv ... | |||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||
| NodePtr allreduce = | |||
| @@ -340,7 +340,7 @@ TEST(UtestTransopNearbyAllreduceFusionPass, test7_all_reduce_with_multiple_trans | |||
| /// TransData TransData ... MatMul ... | |||
| /// \ | / / / | |||
| /// HcomAllReduce | |||
| /// / | \ \ \ | |||
| /// / | \ \ \. | |||
| /// TransData TransData ... RealDiv ... | |||
| size_t symmetric_transdata_num = 20; | |||
| size_t asymmetric_transdata_num = 20; | |||
| @@ -66,7 +66,7 @@ namespace { | |||
| /// transdata2 | |||
| /// | | |||
| /// assign1 | |||
| /// / \ | |||
| /// / \. | |||
| /// transdata1 | | |||
| /// | | | |||
| /// var1 const1 | |||
| @@ -35,8 +35,8 @@ namespace { | |||
| /// shapeNo1 | |||
| /// | | |||
| /// addnYes1 | |||
| /// / \ | |||
| /// / \ | |||
| /// / \. | |||
| /// / \. | |||
| /// const1 const2 | |||
| ComputeGraphPtr BuildGraph1() { | |||
| @@ -57,9 +57,9 @@ ComputeGraphPtr BuildGraph1() { | |||
| /// | |||
| /// netoutput1 | |||
| /// / \ \ | |||
| /// add1 assign1 \ | |||
| /// / \ / \ \ | |||
| /// / \ \. | |||
| /// add1 assign1 \. | |||
| /// / \ / \ \. | |||
| /// var1 var2 const1 var3 | |||
| ComputeGraphPtr BuildGraph2() { | |||