diff --git a/mindinsight/debugger/conditionmgr/recommender.py b/mindinsight/debugger/conditionmgr/recommender.py index 24788d51..b9952c0f 100644 --- a/mindinsight/debugger/conditionmgr/recommender.py +++ b/mindinsight/debugger/conditionmgr/recommender.py @@ -27,11 +27,7 @@ from mindinsight.debugger.conditionmgr.condition import ActivationFuncEnum from mindinsight.debugger.conditionmgr.common.utils import NodeBasicInfo from mindinsight.debugger.conditionmgr.log import logger from mindinsight.conf import settings - - -UNSELECTED_STATUS = 0 -HALF_SELECTED_STATUS = 1 -SELECTED_STATUS = 2 +from mindinsight.debugger.stream_cache.watchpoint import WatchNodeTree class _WatchPointData: @@ -338,63 +334,74 @@ def _get_basic_node_info_by_node_category(node_category, graph_stream, activatio return all_graph_nodes -def _merge_nodes(leaf_nodes, graph): - """merge nodes in one graph""" - unmerged_tree = graph.get_nodes(leaf_nodes) +def _convert_tree_to_node_list(node_tree, node_list): + """Convert WatchNodeTree to Node list.""" + if node_tree.watch_status in [WatchNodeTree.NOT_WATCH, WatchNodeTree.INVALID]: + logger.debug("The watch_status of node: %s is not_watch or invalid.", node_tree.node_name) + return + if node_tree.watch_status == WatchNodeTree.TOTAL_WATCH: + node_basic_info = NodeBasicInfo(name=node_tree.node_name, full_name=node_tree.full_name, + type=node_tree.node_type) + node_list.append(node_basic_info) + return + if node_tree.watch_status == WatchNodeTree.PARTIAL_WATCH: + for _, sub_tree in node_tree.get_children(): + _convert_tree_to_node_list(sub_tree, node_list) + + +def _update_watch_status(node_tree, graph): + """Update the watch_status, if all sub_nodes of a WatchNodeTree are total_watch, + then the WatchNodeTree is changed to total_watch status.""" tmp_node_queue = Queue.Queue() + tmp_node_queue.put(node_tree) # watch node list in layer order - watch_nodes = [] - for node in unmerged_tree: - if node["type"] != "name_scope": - # if node is leaf_node, it is totally chosen - node["status"] = SELECTED_STATUS - else: - # if node is not leaf_node, it is not chosen initially - node["status"] = UNSELECTED_STATUS - tmp_node_queue.put(node) + watch_tree_list = [] while not tmp_node_queue.empty(): - cur_node = tmp_node_queue.get() - watch_nodes.append(cur_node) - for sub_node in cur_node["nodes"]: - if sub_node["type"] != "name_scope": - # if node is leaf_node, it is totally chosen - sub_node["status"] = SELECTED_STATUS - else: - # if node is not leaf_node, it is not chosen initially - sub_node["status"] = UNSELECTED_STATUS - tmp_node_queue.put(sub_node) - - merged_watch_nodes = [] - while watch_nodes: - cur_node = watch_nodes.pop() - node_name = cur_node["name"] + cur_tree = tmp_node_queue.get() + watch_tree_list.append(cur_tree) + for _, sub_tree in cur_tree.get_children(): + tmp_node_queue.put(sub_tree) + + # update the watch_status from bottom to top + while watch_tree_list: + cur_tree = watch_tree_list.pop() + node_name = cur_tree.node_name + logger.debug("Update status of node: %s.", node_name) + + # if node_name is "", it is the root node, which is not in normal_node_map + if not node_name: + continue sub_count = graph.normal_node_map.get(node_name).subnode_count - if len(cur_node["nodes"]) < sub_count: + + # if the children_count of WatchNodeTree is less than the responding subnode_count in the graph, + # its watch_status must be partial_watch + if cur_tree.get_children_count() < sub_count: continue is_all_chosen = True - for sub_node in cur_node["nodes"]: - if sub_node["status"] != SELECTED_STATUS: + for _, sub_tree in cur_tree.get_children(): + if sub_tree.watch_status != WatchNodeTree.TOTAL_WATCH: is_all_chosen = False break if is_all_chosen: - cur_node["status"] = SELECTED_STATUS - merged_watch_nodes.append(cur_node) - else: - cur_node["status"] = HALF_SELECTED_STATUS - logger.debug("merged_watch_nodes: %s", merged_watch_nodes) + cur_tree.watch_status = WatchNodeTree.TOTAL_WATCH + + +def _merge_nodes(leaf_nodes, graph): + """Merge nodes in one graph.""" + watch_node_tree = WatchNodeTree() + for node in leaf_nodes: + watch_node_tree.add_node(node.name, node.type, node.full_name) + _update_watch_status(watch_node_tree, graph) out_nodes = [] - for node_info in merged_watch_nodes: - full_name = graph.get_full_name_by_node_name(node_info["name"]) - node_basic_info = NodeBasicInfo(name=node_info["name"], full_name=full_name, type=node_info["type"]) - out_nodes.append(node_basic_info) + _convert_tree_to_node_list(watch_node_tree, out_nodes) logger.debug("out_nodes: %s", out_nodes) return out_nodes def _add_graph_name(nodes, graph_stream): - """add graph_name in node.name""" + """Add graph_name in node.name.""" if len(graph_stream.graph) > 1: return nodes graph_name = graph_stream.graph_names[0] @@ -407,7 +414,7 @@ def _add_graph_name(nodes, graph_stream): def _sigmoid(value): - """calculate the sigmoid of value""" + """Calculate the sigmoid of value.""" return 1.0 / (1.0 + math.exp(value)) diff --git a/mindinsight/debugger/stream_cache/watchpoint.py b/mindinsight/debugger/stream_cache/watchpoint.py index 3163d44a..41131f48 100644 --- a/mindinsight/debugger/stream_cache/watchpoint.py +++ b/mindinsight/debugger/stream_cache/watchpoint.py @@ -84,6 +84,11 @@ class WatchNodeTree: """The property of watch status about current node.""" return self._watch_status + @watch_status.setter + def watch_status(self, value): + """Set the node watch_status.""" + self._watch_status = value + def update_metadata(self, node_type, full_name, watch_status): """Update the metadata for watched node.""" self._full_name = full_name @@ -107,6 +112,10 @@ class WatchNodeTree: for name_scope, sub_watch_node in self._children.items(): yield name_scope, sub_watch_node + def get_children_count(self): + """Get the count of children nodes.""" + return len(self._children) + def add_node(self, node_name, node_type, full_name=''): """ Add watch node to watch node tree.