From 242afc4e6799a8910328805e8774c83f84e3ef9c Mon Sep 17 00:00:00 2001 From: zhaoxinxin Date: Sat, 13 Mar 2021 17:30:39 +0800 Subject: [PATCH 1/6] modified: ge/graph/passes/base_pass.cc modified: ge/graph/passes/base_pass.h modified: ge/graph/passes/infershape_pass.cc --- ge/graph/passes/base_pass.cc | 43 ++++++++++++++++++++++-------- ge/graph/passes/base_pass.h | 11 ++++++++ ge/graph/passes/infershape_pass.cc | 16 +++++++++++ 3 files changed, 59 insertions(+), 11 deletions(-) diff --git a/ge/graph/passes/base_pass.cc b/ge/graph/passes/base_pass.cc index 3b854c18..64342509 100755 --- a/ge/graph/passes/base_pass.cc +++ b/ge/graph/passes/base_pass.cc @@ -31,7 +31,7 @@ constexpr size_t kMaxOneInNodes = 1000; // Each iteration, we take about 0.3k memory on the stack, we should change the recursion to loop later constexpr int kMaxRecursiveDepth = 20; -void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::queue &input_edge_nodes, +void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque &input_edge_nodes, std::unordered_set &nodes_seen, std::unordered_set &nodes_last) { nodes_last.clear(); for (auto &node : graph->GetDirectNode()) { @@ -40,7 +40,7 @@ void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::queue &i } size_t in_nums = node->GetInNodes().size(); if (in_nums == 0) { - input_edge_nodes.push(node); + input_edge_nodes.push_back(node); nodes_seen.insert(node.get()); } else if (in_nums > kMaxOneInNodes) { nodes_last.insert(node); @@ -48,7 +48,7 @@ void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::queue &i } } -void AddNextIterNodes(const Node::Vistor &nodes, std::queue &nodes_to_pass, +void AddNextIterNodes(const Node::Vistor &nodes, std::deque &nodes_to_pass, std::unordered_set &nodes_seen, std::unordered_set &nodes_last) { for (auto &node : nodes) { if (node == nullptr) { @@ -60,13 +60,14 @@ void AddNextIterNodes(const Node::Vistor &nodes, std::queue &n bool all_in_nodes_seen = node->IsAllInNodesSeen(nodes_seen); if (all_in_nodes_seen && nodes_seen.insert(node.get()).second) { - nodes_to_pass.push(node); + nodes_to_pass.push_back(node); } } } Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, std::unordered_set &nodes_re_pass, - std::unordered_set &nodes_deleted, std::unordered_set &nodes_seen) { + std::unordered_set &nodes_re_pass_immediately, std::unordered_set &nodes_deleted, + std::unordered_set &nodes_seen) { if (node == nullptr) { GELOGE(FAILED, "parameter is null."); return FAILED; @@ -104,6 +105,21 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, std::unorder } } + auto nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately(); + for (const auto &node_to_re_pass : nodes_to_re_pass_immediately) { + if (node_to_re_pass == nullptr) { + GELOGW("Found null re-pass node when executing %s on node %s type %s", name_to_pass.first.c_str(), + node->GetName().c_str(), node->GetType().c_str()); + continue; + } + if (nodes_seen.count(node_to_re_pass.get()) > 0 || node_to_re_pass->IsAllInNodesSeen(nodes_seen)) { + GELOGD("The node %s will be re-pass immediately.", node_to_re_pass->GetName().c_str()); + nodes_re_pass_immediately.insert(node_to_re_pass); + } else { + GELOGD("The node %s are not all seen, don't set repass this time", node_to_re_pass->GetName().c_str()); + } + } + auto nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); if (nodes_deleted_by_pass.count(node) > 0) { @@ -181,10 +197,11 @@ Status GEPass::Run(const NamesToPass &names_to_passes) { Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { GELOGD("Begin to run pass on graph, passes count %zu", names_to_passes.size()); - std::queue nodes; + std::deque nodes; std::unordered_set nodes_seen; std::unordered_set nodes_deleted; std::unordered_set nodes_re_pass; + std::unordered_set nodes_re_pass_immediately; std::unordered_set nodes_last; GetAllNodesNoInputEdge(graph_, nodes, nodes_seen, nodes_last); GELOGD("Start points count %zu", nodes.size()); @@ -192,14 +209,14 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { do { for (auto &node : nodes_re_pass) { - nodes.push(node); + nodes.push_back(node); nodes_seen.insert(node.get()); } nodes_re_pass.clear(); while (!nodes.empty()) { NodePtr node = nodes.front(); - nodes.pop(); + nodes.pop_front(); (void)nodes_re_pass.erase(node); GE_IF_BOOL_EXEC(node == nullptr, GELOGW("node is null"); continue); @@ -210,7 +227,7 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { AddNextIterNodes(node->GetOutNodes(), nodes, nodes_seen, nodes_last); - auto ret = RunPasses(node, names_to_passes, nodes_re_pass, nodes_deleted, nodes_seen); + auto ret = RunPasses(node, names_to_passes, nodes_re_pass, nodes_re_pass_immediately, nodes_deleted, nodes_seen); if (ret != SUCCESS) { GELOGE(ret, "Failed to process passes on node %s type %s, error code: %u", node->GetName().c_str(), node->GetType().c_str(), ret); @@ -227,7 +244,7 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { if (has_sub_graph) { GELOGD("There are subgraphs on node %s, run passes for for the second time", node->GetName().c_str()); SetFlagOption(kOptimizeAfterSubGraph, names_to_passes); - ret = RunPasses(node, names_to_passes, nodes_re_pass, nodes_deleted, nodes_seen); + ret = RunPasses(node, names_to_passes, nodes_re_pass, nodes_re_pass_immediately, nodes_deleted, nodes_seen); if (ret != SUCCESS) { GELOGE(ret, "Failed to process passes on node %s type %s, error code: %u", node->GetName().c_str(), node->GetType().c_str(), ret); @@ -239,12 +256,16 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { // should be called each time at the begin of the iteration ClearOption(names_to_passes); } + for(auto &node : nodes_re_pass_immediately){ + nodes.push_front(node); + } + nodes_re_pass_immediately.clear(); } for (auto &node : nodes_last) { bool all_in_nodes_seen = node->IsAllInNodesSeen(nodes_seen); if (all_in_nodes_seen && nodes_seen.insert(node.get()).second) { - nodes.push(node); + nodes.push_back(node); } } nodes_last.clear(); diff --git a/ge/graph/passes/base_pass.h b/ge/graph/passes/base_pass.h index bb41691d..89a364a9 100644 --- a/ge/graph/passes/base_pass.h +++ b/ge/graph/passes/base_pass.h @@ -53,6 +53,8 @@ class BaseNodePass { std::unordered_set GetNodesNeedRePass() { return nodes_need_re_pass_; } + std::unordered_set GetNodesNeedRePassImmediately() { return nodes_need_re_pass_immediately_; } + std::unordered_set GetNodesDeleted() { return nodes_deleted_; } void SetOption(NodePassOption option, const std::string &value) { options_[option] = value; } @@ -79,6 +81,14 @@ class BaseNodePass { /// void AddRePassNode(NodePtr &node) { nodes_need_re_pass_.insert(node); } + /// + /// Add a node to be optimized immediately again. If you add a new node to the graph, or + /// change a node connections, and you want to make sure the node will be + /// optimized by other passes, call this function. + /// @param node + /// + void AddImmediateRePassNode(NodePtr &node) { nodes_need_re_pass_immediately_.insert(node); } + /// /// Add a node and it's input/output data nodes to be optimized again. /// @param node @@ -109,6 +119,7 @@ class BaseNodePass { private: std::unordered_set nodes_need_re_pass_; + std::unordered_set nodes_need_re_pass_immediately_; std::unordered_set nodes_deleted_; std::map options_; }; diff --git a/ge/graph/passes/infershape_pass.cc b/ge/graph/passes/infershape_pass.cc index 7b8f7b50..fd943c2d 100755 --- a/ge/graph/passes/infershape_pass.cc +++ b/ge/graph/passes/infershape_pass.cc @@ -25,6 +25,7 @@ namespace ge { Status InferShapePass::Run(NodePtr &node) { + // kOptimizeAfterSubGraph exist means after subgraph auto ret = ShapeRefiner::InferShapeAndType(node, !OptionExists(kOptimizeAfterSubGraph)); if (ret != GRAPH_SUCCESS) { // select INFERSHAPE failed info @@ -41,6 +42,21 @@ Status InferShapePass::Run(NodePtr &node) { GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "infershape failed. node: %s", node->GetName().c_str()); return GE_GRAPH_INFERSHAPE_FAILED; } + if(node->GetType() == WHILE){ + bool need_repass = false; + AttrUtils::GetBool(node->GetOpDesc(),"need_infer_again_", need_repass); + if(!OptionExists(kOptimizeAfterSubGraph)){ + return SUCCESS; + } + if(need_repass){ + AddImmediateRePassNode(node); + GELOGD("Node %s need repass immediately.", node->GetName().c_str()); + } + else{ + // clear attr on while + node->GetOpDesc()->DelAttr("need_infer_again_"); + } + } return SUCCESS; } } // namespace ge From c067e32c68ff801bb775a9b13f19a0f3c95acab6 Mon Sep 17 00:00:00 2001 From: zhaoxinxin Date: Sat, 13 Mar 2021 20:22:01 +0800 Subject: [PATCH 2/6] modified: ge/graph/passes/base_pass.h --- ge/graph/passes/base_pass.h | 1 + 1 file changed, 1 insertion(+) diff --git a/ge/graph/passes/base_pass.h b/ge/graph/passes/base_pass.h index 89a364a9..a9f4f000 100644 --- a/ge/graph/passes/base_pass.h +++ b/ge/graph/passes/base_pass.h @@ -64,6 +64,7 @@ class BaseNodePass { void init() { nodes_need_re_pass_.clear(); nodes_deleted_.clear(); + nodes_need_re_pass_immediately_.clear(); } protected: From f203c70cfdaefd8f9b750ed713dcc746e80598d4 Mon Sep 17 00:00:00 2001 From: zhaoxinxin Date: Mon, 15 Mar 2021 14:59:46 +0800 Subject: [PATCH 3/6] modified: tests/ut/ge/graph/passes/base_pass_unittest.cc --- .../ut/ge/graph/passes/base_pass_unittest.cc | 81 +++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/tests/ut/ge/graph/passes/base_pass_unittest.cc b/tests/ut/ge/graph/passes/base_pass_unittest.cc index 56a7077a..b1934359 100644 --- a/tests/ut/ge/graph/passes/base_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/base_pass_unittest.cc @@ -67,6 +67,21 @@ class UtestTestPass : public BaseNodePass { names_to_add_repass_.erase(iter); } } + // simulate infershape pass + if(node->GetType() == WHILE){ + bool need_repass = false; + AttrUtils::GetBool(node->GetOpDesc(),"need_infer_again_", need_repass); + if(!OptionExists(kOptimizeAfterSubGraph)){ + return SUCCESS; + } + if(need_repass){ + AddImmediateRePassNode(node); + } + else{ + // clear attr on while + node->GetOpDesc()->DelAttr("need_infer_again_"); + } + } return SUCCESS; } void clear() { iter_nodes_.clear(); } @@ -429,6 +444,7 @@ TEST_F(UTESTGraphPassesBasePass, dead_loop) { EXPECT_EQ(test_pass.GetRunTimes(), 1007); } */ + TEST_F(UTESTGraphPassesBasePass, while_loop) { NamesToPass names_to_pass; auto test_pass = UtestTestPass(true); @@ -438,4 +454,69 @@ TEST_F(UTESTGraphPassesBasePass, while_loop) { auto ge_pass = GEPass(graph); EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); } + +/// data1 const +/// \ / +/// while +/// / \ +/// | | +/// cast1 cast2 +ComputeGraphPtr BuildWhileGraph1() { + // build sub graph + auto builder_sub = ut::GraphBuilder("sub"); + auto data_1 = builder_sub.AddNode("data_1", DATA, 0, 1); + auto data_2 = builder_sub.AddNode("data_2", DATA, 0, 1); + auto add = builder_sub.AddNode("add", ADD, 2, 1); + + builder_sub.AddDataEdge(data_1, 0, add, 0); + builder_sub.AddDataEdge(data_2, 0, add, 1); + auto sub_graph = builder_sub.GetGraph(); + sub_graph->SetName("while_sub"); + // build root graph + auto builder = ut::GraphBuilder("g1"); + auto data = builder.AddNode("data1", DATA, 0, 1); + auto const_op = builder.AddNode("const_op", CONSTANT, 0, 1); + auto c1 = builder.AddNode("cast1", CAST, 1, 1); + auto c2 = builder.AddNode("cast2", CAST, 1, 1); + // add while op + auto tensor_desc = std::make_shared(); + tensor_desc->SetShape(GeShape({1,1,1,1})); + tensor_desc->SetFormat(FORMAT_ND); + tensor_desc->SetDataType(DT_INT32); + + auto op_desc = std::make_shared("while", WHILE); + for (int i = 0; i < 2; ++i) { + op_desc->AddInputDesc(tensor_desc->Clone()); + } + for (int i = 0; i < 2; ++i) { + op_desc->AddOutputDesc(tensor_desc->Clone()); + } + AttrUtils::SetBool(op_desc,"need_infer_again_", true); + op_desc->AddSubgraphName(sub_graph->GetName()); + op_desc->SetSubgraphInstanceName(0,sub_graph->GetName()); + auto root_graph = builder.GetGraph(); + auto while_op = root_graph->AddNode(op_desc); + + builder.AddDataEdge(data, 0, while_op, 0); + builder.AddDataEdge(const_op, 0, while_op, 1); + builder.AddDataEdge(while_op, 0, c1, 0); + builder.AddDataEdge(while_op, 1, c2, 0); + sub_graph->SetParentGraph(root_graph); + sub_graph->SetParentNode(while_op); + root_graph->AddSubgraph(sub_graph); + return root_graph; +} + +TEST_F(UTESTGraphPassesBasePass, while_infershape) { +NamesToPass names_to_pass; +auto test_pass = UtestTestPass(); +names_to_pass.push_back(std::make_pair("test", &test_pass)); + +auto graph = BuildWhileGraph1(); +auto ge_pass = GEPass(graph); +auto while_node = graph->FindNode("while"); +EXPECT_EQ(while_node->GetOpDesc()->GetSubgraphInstanceNames().size(),1); +EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); +} + } // namespace ge From 2572bed425d2ca84cb6515eeb1c2ccd724e37982 Mon Sep 17 00:00:00 2001 From: zhaoxinxin Date: Mon, 15 Mar 2021 15:37:06 +0800 Subject: [PATCH 4/6] modified: tests/ut/ge/graph/passes/base_pass_unittest.cc --- tests/ut/ge/graph/passes/base_pass_unittest.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/ut/ge/graph/passes/base_pass_unittest.cc b/tests/ut/ge/graph/passes/base_pass_unittest.cc index b1934359..129c11d8 100644 --- a/tests/ut/ge/graph/passes/base_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/base_pass_unittest.cc @@ -75,6 +75,7 @@ class UtestTestPass : public BaseNodePass { return SUCCESS; } if(need_repass){ + AttrUtils::SetBool(node->GetOpDesc(),"need_infer_again_", false); AddImmediateRePassNode(node); } else{ From c293465b6cff8b3618772dfe175a843c50089d0c Mon Sep 17 00:00:00 2001 From: zhaoxinxin Date: Mon, 15 Mar 2021 20:17:52 +0800 Subject: [PATCH 5/6] modified: ge/graph/passes/base_pass.cc modified: ge/graph/passes/infershape_pass.cc --- ge/graph/passes/base_pass.cc | 100 ++++++++++++++--------------- ge/graph/passes/infershape_pass.cc | 11 ++-- 2 files changed, 54 insertions(+), 57 deletions(-) diff --git a/ge/graph/passes/base_pass.cc b/ge/graph/passes/base_pass.cc index 64342509..0868b729 100755 --- a/ge/graph/passes/base_pass.cc +++ b/ge/graph/passes/base_pass.cc @@ -30,6 +30,13 @@ constexpr int kMaxRePassTimes = 10000; constexpr size_t kMaxOneInNodes = 1000; // Each iteration, we take about 0.3k memory on the stack, we should change the recursion to loop later constexpr int kMaxRecursiveDepth = 20; +struct DuringPassNodeSets { + std::unordered_set nodes_seen; + std::unordered_set nodes_deleted; + std::unordered_set nodes_re_pass; + std::unordered_set nodes_re_pass_immediately; + std::unordered_set nodes_last; +}; void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque &input_edge_nodes, std::unordered_set &nodes_seen, std::unordered_set &nodes_last) { @@ -65,9 +72,25 @@ void AddNextIterNodes(const Node::Vistor &nodes, std::deque &n } } -Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, std::unordered_set &nodes_re_pass, - std::unordered_set &nodes_re_pass_immediately, std::unordered_set &nodes_deleted, - std::unordered_set &nodes_seen) { +void PushToRePassIfSeen(NodePtr &node, const std::pair &name_to_pass, + std::unordered_set &nodes_seen, std::unordered_set &nodes_to_re_pass, + std::unordered_set &nodes_re_pass) { + for (const auto &node_to_re_pass : nodes_to_re_pass) { + if (node_to_re_pass == nullptr) { + GELOGW("Found null re-pass node when executing %s on node %s type %s", name_to_pass.first.c_str(), + node->GetName().c_str(), node->GetType().c_str()); + continue; + } + if (nodes_seen.count(node_to_re_pass.get()) > 0 || node_to_re_pass->IsAllInNodesSeen(nodes_seen)) { + GELOGD("The node %s will be re-pass.", node_to_re_pass->GetName().c_str()); + nodes_re_pass.insert(node_to_re_pass); + } else { + GELOGD("The node %s are not all seen, don't set repass this time", node_to_re_pass->GetName().c_str()); + } + } +} + +Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, DuringPassNodeSets &during_pass_node_set) { if (node == nullptr) { GELOGE(FAILED, "parameter is null."); return FAILED; @@ -91,37 +114,15 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, std::unorder } auto nodes_to_re_pass = name_to_pass.second->GetNodesNeedRePass(); - for (const auto &node_to_re_pass : nodes_to_re_pass) { - if (node_to_re_pass == nullptr) { - GELOGW("Found null re-pass node when executing %s on node %s type %s", name_to_pass.first.c_str(), - node->GetName().c_str(), node->GetType().c_str()); - continue; - } - if (nodes_seen.count(node_to_re_pass.get()) > 0 || node_to_re_pass->IsAllInNodesSeen(nodes_seen)) { - GELOGD("The node %s will be re-pass later", node_to_re_pass->GetName().c_str()); - nodes_re_pass.insert(node_to_re_pass); - } else { - GELOGD("The node %s are not all seen, don't set repass this time", node_to_re_pass->GetName().c_str()); - } - } + PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass, + during_pass_node_set.nodes_re_pass); auto nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately(); - for (const auto &node_to_re_pass : nodes_to_re_pass_immediately) { - if (node_to_re_pass == nullptr) { - GELOGW("Found null re-pass node when executing %s on node %s type %s", name_to_pass.first.c_str(), - node->GetName().c_str(), node->GetType().c_str()); - continue; - } - if (nodes_seen.count(node_to_re_pass.get()) > 0 || node_to_re_pass->IsAllInNodesSeen(nodes_seen)) { - GELOGD("The node %s will be re-pass immediately.", node_to_re_pass->GetName().c_str()); - nodes_re_pass_immediately.insert(node_to_re_pass); - } else { - GELOGD("The node %s are not all seen, don't set repass this time", node_to_re_pass->GetName().c_str()); - } - } + PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass_immediately, + during_pass_node_set.nodes_re_pass_immediately); auto nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); - nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); + during_pass_node_set.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); if (nodes_deleted_by_pass.count(node) > 0) { GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(), name_to_pass.first.c_str()); @@ -198,36 +199,32 @@ Status GEPass::Run(const NamesToPass &names_to_passes) { Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { GELOGD("Begin to run pass on graph, passes count %zu", names_to_passes.size()); std::deque nodes; - std::unordered_set nodes_seen; - std::unordered_set nodes_deleted; - std::unordered_set nodes_re_pass; - std::unordered_set nodes_re_pass_immediately; - std::unordered_set nodes_last; - GetAllNodesNoInputEdge(graph_, nodes, nodes_seen, nodes_last); + DuringPassNodeSets during_pass_node_set; + GetAllNodesNoInputEdge(graph_, nodes, during_pass_node_set.nodes_seen, during_pass_node_set.nodes_last); GELOGD("Start points count %zu", nodes.size()); int re_pass_times = 0; do { - for (auto &node : nodes_re_pass) { + for (auto &node : during_pass_node_set.nodes_re_pass) { nodes.push_back(node); - nodes_seen.insert(node.get()); + during_pass_node_set.nodes_seen.insert(node.get()); } - nodes_re_pass.clear(); + during_pass_node_set.nodes_re_pass.clear(); while (!nodes.empty()) { NodePtr node = nodes.front(); nodes.pop_front(); - (void)nodes_re_pass.erase(node); + (void)during_pass_node_set.nodes_re_pass.erase(node); GE_IF_BOOL_EXEC(node == nullptr, GELOGW("node is null"); continue); - if (nodes_deleted.count(node) > 0) { + if (during_pass_node_set.nodes_deleted.count(node) > 0) { GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str()); continue; } - AddNextIterNodes(node->GetOutNodes(), nodes, nodes_seen, nodes_last); + AddNextIterNodes(node->GetOutNodes(), nodes, during_pass_node_set.nodes_seen, during_pass_node_set.nodes_last); - auto ret = RunPasses(node, names_to_passes, nodes_re_pass, nodes_re_pass_immediately, nodes_deleted, nodes_seen); + auto ret = RunPasses(node, names_to_passes, during_pass_node_set); if (ret != SUCCESS) { GELOGE(ret, "Failed to process passes on node %s type %s, error code: %u", node->GetName().c_str(), node->GetType().c_str(), ret); @@ -244,7 +241,7 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { if (has_sub_graph) { GELOGD("There are subgraphs on node %s, run passes for for the second time", node->GetName().c_str()); SetFlagOption(kOptimizeAfterSubGraph, names_to_passes); - ret = RunPasses(node, names_to_passes, nodes_re_pass, nodes_re_pass_immediately, nodes_deleted, nodes_seen); + ret = RunPasses(node, names_to_passes, during_pass_node_set); if (ret != SUCCESS) { GELOGE(ret, "Failed to process passes on node %s type %s, error code: %u", node->GetName().c_str(), node->GetType().c_str(), ret); @@ -256,20 +253,21 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { // should be called each time at the begin of the iteration ClearOption(names_to_passes); } - for(auto &node : nodes_re_pass_immediately){ + for (const auto &node : during_pass_node_set.nodes_re_pass_immediately) { + GELOGD("The node %s will be re-pass immediately.", node->GetName().c_str()); nodes.push_front(node); } - nodes_re_pass_immediately.clear(); + during_pass_node_set.nodes_re_pass_immediately.clear(); } - for (auto &node : nodes_last) { - bool all_in_nodes_seen = node->IsAllInNodesSeen(nodes_seen); - if (all_in_nodes_seen && nodes_seen.insert(node.get()).second) { + for (auto &node : during_pass_node_set.nodes_last) { + bool all_in_nodes_seen = node->IsAllInNodesSeen(during_pass_node_set.nodes_seen); + if (all_in_nodes_seen && during_pass_node_set.nodes_seen.insert(node.get()).second) { nodes.push_back(node); } } - nodes_last.clear(); - } while ((!nodes_re_pass.empty() || !nodes.empty()) && ++re_pass_times < kMaxRePassTimes); + during_pass_node_set.nodes_last.clear(); + } while ((!during_pass_node_set.nodes_re_pass.empty() || !nodes.empty()) && ++re_pass_times < kMaxRePassTimes); if (re_pass_times == kMaxRePassTimes) { GELOGW("re_pass_times should not come to %d", kMaxRePassTimes); diff --git a/ge/graph/passes/infershape_pass.cc b/ge/graph/passes/infershape_pass.cc index fd943c2d..fb18204c 100755 --- a/ge/graph/passes/infershape_pass.cc +++ b/ge/graph/passes/infershape_pass.cc @@ -42,17 +42,16 @@ Status InferShapePass::Run(NodePtr &node) { GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "infershape failed. node: %s", node->GetName().c_str()); return GE_GRAPH_INFERSHAPE_FAILED; } - if(node->GetType() == WHILE){ + if (node->GetType() == WHILE) { bool need_repass = false; - AttrUtils::GetBool(node->GetOpDesc(),"need_infer_again_", need_repass); - if(!OptionExists(kOptimizeAfterSubGraph)){ + AttrUtils::GetBool(node->GetOpDesc(), "need_infer_again_", need_repass); + if (!OptionExists(kOptimizeAfterSubGraph)) { return SUCCESS; } - if(need_repass){ + if (need_repass) { AddImmediateRePassNode(node); GELOGD("Node %s need repass immediately.", node->GetName().c_str()); - } - else{ + } else { // clear attr on while node->GetOpDesc()->DelAttr("need_infer_again_"); } From 848236b21c0f492f141354f852d0ac23d8b45d63 Mon Sep 17 00:00:00 2001 From: zhaoxinxin Date: Mon, 15 Mar 2021 20:47:19 +0800 Subject: [PATCH 6/6] modified: ge/graph/passes/infershape_pass.cc --- ge/graph/passes/infershape_pass.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ge/graph/passes/infershape_pass.cc b/ge/graph/passes/infershape_pass.cc index fb18204c..a54a15c1 100755 --- a/ge/graph/passes/infershape_pass.cc +++ b/ge/graph/passes/infershape_pass.cc @@ -42,9 +42,9 @@ Status InferShapePass::Run(NodePtr &node) { GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "infershape failed. node: %s", node->GetName().c_str()); return GE_GRAPH_INFERSHAPE_FAILED; } - if (node->GetType() == WHILE) { - bool need_repass = false; - AttrUtils::GetBool(node->GetOpDesc(), "need_infer_again_", need_repass); + bool need_repass = false; + auto has_attr = AttrUtils::GetBool(node->GetOpDesc(), "need_infer_again_", need_repass); + if (has_attr) { if (!OptionExists(kOptimizeAfterSubGraph)) { return SUCCESS; }