Use OrderedSet for order list to optimize performance.tags/v1.2.0-rc1
| @@ -4,6 +4,9 @@ mindspore/lib | |||||
| output | output | ||||
| *.ir | *.ir | ||||
| st_tests | st_tests | ||||
| kernel_meta/ | |||||
| somas_meta/ | |||||
| trace_code_graph_* | |||||
| # mindspore lite java | # mindspore lite java | ||||
| mindspore/lite/java/java/.gradle | mindspore/lite/java/java/.gradle | ||||
| @@ -323,7 +323,7 @@ class SideEffectFinder { | |||||
| } | } | ||||
| static void UpdateOrderList(const FuncGraphPtr &func_graph) { | static void UpdateOrderList(const FuncGraphPtr &func_graph) { | ||||
| std::list<CNodePtr> new_order_list; | |||||
| OrderedSet<CNodePtr> new_order_list; | |||||
| const auto &order_list = func_graph->order_list(); | const auto &order_list = func_graph->order_list(); | ||||
| for (auto &cnode : order_list) { | for (auto &cnode : order_list) { | ||||
| PushToOrderList(func_graph, cnode, &new_order_list); | PushToOrderList(func_graph, cnode, &new_order_list); | ||||
| @@ -331,10 +331,9 @@ class SideEffectFinder { | |||||
| func_graph->set_order_list(std::move(new_order_list)); | func_graph->set_order_list(std::move(new_order_list)); | ||||
| } | } | ||||
| static void PushToOrderList(const FuncGraphPtr &fg, const CNodePtr &cnode, std::list<CNodePtr> *new_order_list) { | |||||
| static void PushToOrderList(const FuncGraphPtr &fg, const CNodePtr &cnode, OrderedSet<CNodePtr> *new_order_list) { | |||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| auto iter = std::find(new_order_list->begin(), new_order_list->end(), cnode); | |||||
| if (iter != new_order_list->end()) { | |||||
| if (new_order_list->contains(cnode)) { | |||||
| return; | return; | ||||
| } | } | ||||
| for (auto &input : cnode->inputs()) { | for (auto &input : cnode->inputs()) { | ||||
| @@ -136,23 +136,21 @@ CNodePtr FuncGraph::NewCNodeInFront(const std::vector<AnfNodePtr> &inputs) { | |||||
| CNodePtr FuncGraph::NewCNodeBefore(const AnfNodePtr &position, const std::vector<AnfNodePtr> &inputs) { | CNodePtr FuncGraph::NewCNodeBefore(const AnfNodePtr &position, const std::vector<AnfNodePtr> &inputs) { | ||||
| CNodePtr cnode = NewCNode(inputs); | CNodePtr cnode = NewCNode(inputs); | ||||
| auto iter = std::find(order_.begin(), order_.end(), position); | |||||
| CNodePtr pos_cnode = dyn_cast<CNode>(position); | |||||
| auto iter = order_.find(pos_cnode); | |||||
| order_.insert(iter, cnode); | order_.insert(iter, cnode); | ||||
| return cnode; | return cnode; | ||||
| } | } | ||||
| CNodePtr FuncGraph::NewCNodeAfter(const AnfNodePtr &position, const std::vector<AnfNodePtr> &inputs) { | CNodePtr FuncGraph::NewCNodeAfter(const AnfNodePtr &position, const std::vector<AnfNodePtr> &inputs) { | ||||
| CNodePtr cnode = NewCNode(inputs); | CNodePtr cnode = NewCNode(inputs); | ||||
| if (!position->isa<CNode>()) { | |||||
| order_.push_front(cnode); | |||||
| return cnode; | |||||
| } | |||||
| auto iter = std::find(order_.begin(), order_.end(), position); | |||||
| CNodePtr pos_cnode = dyn_cast<CNode>(position); | |||||
| auto iter = order_.find(pos_cnode); | |||||
| if (iter == order_.end()) { | if (iter == order_.end()) { | ||||
| order_.push_front(cnode); | order_.push_front(cnode); | ||||
| return cnode; | |||||
| } else { | |||||
| order_.insert(std::next(iter), cnode); | |||||
| } | } | ||||
| order_.insert(std::next(iter), cnode); | |||||
| return cnode; | return cnode; | ||||
| } | } | ||||
| @@ -616,7 +614,7 @@ void FuncGraph::EraseUnusedNodeInOrder(const AnfNodePtr &node) { | |||||
| if (node) { | if (node) { | ||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| if (cnode) { | if (cnode) { | ||||
| order_.remove(cnode); | |||||
| order_.erase(cnode); | |||||
| MS_LOG(DEBUG) << "Remove the node" << node->DebugString() << " from order list."; | MS_LOG(DEBUG) << "Remove the node" << node->DebugString() << " from order list."; | ||||
| } | } | ||||
| } | } | ||||
| @@ -636,7 +634,7 @@ void FuncGraph::ReplaceInOrder(const AnfNodePtr &old_node, const AnfNodePtr &new | |||||
| return; | return; | ||||
| } | } | ||||
| // Search old node in order list. | // Search old node in order list. | ||||
| auto iter = std::find(order_.begin(), order_.end(), old_cnode); | |||||
| auto iter = order_.find(old_cnode); | |||||
| if (iter == order_.end()) { | if (iter == order_.end()) { | ||||
| // Skip if old node not found in order list. | // Skip if old node not found in order list. | ||||
| return; | return; | ||||
| @@ -359,15 +359,15 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder { | |||||
| void EraseUnusedNodeInOrder(const AnfNodePtr &n); | void EraseUnusedNodeInOrder(const AnfNodePtr &n); | ||||
| void EraseUnusedNodeInOrder(); | void EraseUnusedNodeInOrder(); | ||||
| void DumpCNodeList(); | void DumpCNodeList(); | ||||
| const std::list<CNodePtr> &order_list() const { return order_; } | |||||
| const OrderedSet<CNodePtr> &order_list() const { return order_; } | |||||
| void set_order_list(std::list<CNodePtr> &&order_list) { order_ = std::move(order_list); } | |||||
| void set_order_list(OrderedSet<CNodePtr> &&order_list) { order_ = std::move(order_list); } | |||||
| // Add a cnode at the end of order list. | // Add a cnode at the end of order list. | ||||
| void AppendOrderList(const CNodePtr &cnode) { order_.push_back(cnode); } | void AppendOrderList(const CNodePtr &cnode) { order_.push_back(cnode); } | ||||
| // Prepend cnode at the front of order list. | // Prepend cnode at the front of order list. | ||||
| void PrependOrderList(const CNodePtr &cnode) { order_.insert(order_.begin(), cnode); } | |||||
| void PrependOrderList(const CNodePtr &cnode) { order_.push_front(cnode); } | |||||
| // Maintain cnode order list when a cnode is replaced by a new one. | // Maintain cnode order list when a cnode is replaced by a new one. | ||||
| void ReplaceInOrder(const AnfNodePtr &old_node, const AnfNodePtr &new_node); | void ReplaceInOrder(const AnfNodePtr &old_node, const AnfNodePtr &new_node); | ||||
| @@ -461,7 +461,7 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder { | |||||
| const std::vector<AnfNodePtr> &kwarg_values_tuple_nodes); | const std::vector<AnfNodePtr> &kwarg_values_tuple_nodes); | ||||
| // CNode order which relates to origin code order | // CNode order which relates to origin code order | ||||
| std::list<CNodePtr> order_; | |||||
| OrderedSet<CNodePtr> order_; | |||||
| bool stub_; | bool stub_; | ||||
| inline static Drawer drawer_ = nullptr; | inline static Drawer drawer_ = nullptr; | ||||
| // Design switch_layer_input as a ptr to | // Design switch_layer_input as a ptr to | ||||
| @@ -58,6 +58,8 @@ class OrderedSet { | |||||
| } | } | ||||
| } | } | ||||
| OrderedSet(OrderedSet &&os) = default; | |||||
| explicit OrderedSet(const sequential_type &other) { | explicit OrderedSet(const sequential_type &other) { | ||||
| for (auto &item : other) { | for (auto &item : other) { | ||||
| add(item); | add(item); | ||||
| @@ -80,23 +82,27 @@ class OrderedSet { | |||||
| return *this; | return *this; | ||||
| } | } | ||||
| // Add an element to the OrderedSet, without judging return value | |||||
| void add(const element_type &e) { (void)insert(e); } | |||||
| // insert an element to the OrderedSet | |||||
| std::pair<iterator, bool> insert(const element_type &e) { | |||||
| iterator empty_itr; | |||||
| std::pair<element_type, typename map_type::mapped_type> map_pair = std::make_pair(e, empty_itr); | |||||
| auto result = mapped_data_.insert(map_pair); | |||||
| auto &seq_idx = result.first->second; | |||||
| // if insert success; | |||||
| OrderedSet &operator=(OrderedSet &&os) = default; | |||||
| // insert an element to the OrderedSet after the given position. | |||||
| std::pair<iterator, bool> insert(iterator pos, const element_type &e) { | |||||
| auto result = mapped_data_.emplace(e, ordered_data_.end()); | |||||
| if (result.second) { | if (result.second) { | ||||
| auto it = ordered_data_.insert(ordered_data_.end(), e); | |||||
| seq_idx = it; | |||||
| result.first->second = ordered_data_.emplace(pos, e); | |||||
| } | } | ||||
| return std::pair<iterator, bool>(seq_idx, result.second); | |||||
| return {result.first->second, result.second}; | |||||
| } | } | ||||
| // Add an element to the OrderedSet, without judging return value | |||||
| void add(const element_type &e) { (void)insert(ordered_data_.end(), e); } | |||||
| // insert an element to the end of OrderedSet. | |||||
| std::pair<iterator, bool> insert(const element_type &e) { return insert(ordered_data_.end(), e); } | |||||
| void push_back(const element_type &e) { (void)insert(ordered_data_.end(), e); } | |||||
| void push_front(const element_type &e) { (void)insert(ordered_data_.begin(), e); } | |||||
| // Remove an element, if removed return true, otherwise return false | // Remove an element, if removed return true, otherwise return false | ||||
| bool erase(const element_type &e) { | bool erase(const element_type &e) { | ||||
| auto pos = mapped_data_.find(e); | auto pos = mapped_data_.find(e); | ||||
| @@ -109,6 +115,16 @@ class OrderedSet { | |||||
| return true; | return true; | ||||
| } | } | ||||
| iterator erase(iterator pos) { | |||||
| (void)mapped_data_.erase(*pos); | |||||
| return ordered_data_.erase(pos); | |||||
| } | |||||
| iterator erase(const_iterator pos) { | |||||
| (void)mapped_data_.erase(*pos); | |||||
| return ordered_data_.erase(pos); | |||||
| } | |||||
| // Return the container size | // Return the container size | ||||
| std::size_t size() const { return mapped_data_.size(); } | std::size_t size() const { return mapped_data_.size(); } | ||||
| @@ -267,6 +283,22 @@ class OrderedSet { | |||||
| bool contains(const element_type &e) const { return (mapped_data_.find(e) != mapped_data_.end()); } | bool contains(const element_type &e) const { return (mapped_data_.find(e) != mapped_data_.end()); } | ||||
| const_iterator find(const element_type &e) const { | |||||
| auto iter = mapped_data_.find(e); | |||||
| if (iter == mapped_data_.end()) { | |||||
| return ordered_data_.end(); | |||||
| } | |||||
| return iter->second; | |||||
| } | |||||
| iterator find(const element_type &e) { | |||||
| auto iter = mapped_data_.find(e); | |||||
| if (iter == mapped_data_.end()) { | |||||
| return ordered_data_.end(); | |||||
| } | |||||
| return iter->second; | |||||
| } | |||||
| // Return the count of an element in set | // Return the count of an element in set | ||||
| std::size_t count(const element_type &e) const { return mapped_data_.count(e); } | std::size_t count(const element_type &e) const { return mapped_data_.count(e); } | ||||