Browse Source

!992 Make the code more elegant

From: @maning202007
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
99430b620d
2 changed files with 62 additions and 46 deletions
  1. +53
    -46
      mindinsight/debugger/conditionmgr/recommender.py
  2. +9
    -0
      mindinsight/debugger/stream_cache/watchpoint.py

+ 53
- 46
mindinsight/debugger/conditionmgr/recommender.py View File

@@ -27,11 +27,7 @@ from mindinsight.debugger.conditionmgr.condition import ActivationFuncEnum
from mindinsight.debugger.conditionmgr.common.utils import NodeBasicInfo
from mindinsight.debugger.conditionmgr.log import logger
from mindinsight.conf import settings


UNSELECTED_STATUS = 0
HALF_SELECTED_STATUS = 1
SELECTED_STATUS = 2
from mindinsight.debugger.stream_cache.watchpoint import WatchNodeTree


class _WatchPointData:
@@ -338,63 +334,74 @@ def _get_basic_node_info_by_node_category(node_category, graph_stream, activatio
return all_graph_nodes


def _merge_nodes(leaf_nodes, graph):
"""merge nodes in one graph"""
unmerged_tree = graph.get_nodes(leaf_nodes)
def _convert_tree_to_node_list(node_tree, node_list):
"""Convert WatchNodeTree to Node list."""
if node_tree.watch_status in [WatchNodeTree.NOT_WATCH, WatchNodeTree.INVALID]:
logger.debug("The watch_status of node: %s is not_watch or invalid.", node_tree.node_name)
return
if node_tree.watch_status == WatchNodeTree.TOTAL_WATCH:
node_basic_info = NodeBasicInfo(name=node_tree.node_name, full_name=node_tree.full_name,
type=node_tree.node_type)
node_list.append(node_basic_info)
return
if node_tree.watch_status == WatchNodeTree.PARTIAL_WATCH:
for _, sub_tree in node_tree.get_children():
_convert_tree_to_node_list(sub_tree, node_list)


def _update_watch_status(node_tree, graph):
"""Update the watch_status, if all sub_nodes of a WatchNodeTree are total_watch,
then the WatchNodeTree is changed to total_watch status."""
tmp_node_queue = Queue.Queue()
tmp_node_queue.put(node_tree)

# watch node list in layer order
watch_nodes = []
for node in unmerged_tree:
if node["type"] != "name_scope":
# if node is leaf_node, it is totally chosen
node["status"] = SELECTED_STATUS
else:
# if node is not leaf_node, it is not chosen initially
node["status"] = UNSELECTED_STATUS
tmp_node_queue.put(node)
watch_tree_list = []
while not tmp_node_queue.empty():
cur_node = tmp_node_queue.get()
watch_nodes.append(cur_node)
for sub_node in cur_node["nodes"]:
if sub_node["type"] != "name_scope":
# if node is leaf_node, it is totally chosen
sub_node["status"] = SELECTED_STATUS
else:
# if node is not leaf_node, it is not chosen initially
sub_node["status"] = UNSELECTED_STATUS
tmp_node_queue.put(sub_node)

merged_watch_nodes = []
while watch_nodes:
cur_node = watch_nodes.pop()
node_name = cur_node["name"]
cur_tree = tmp_node_queue.get()
watch_tree_list.append(cur_tree)
for _, sub_tree in cur_tree.get_children():
tmp_node_queue.put(sub_tree)

# update the watch_status from bottom to top
while watch_tree_list:
cur_tree = watch_tree_list.pop()
node_name = cur_tree.node_name
logger.debug("Update status of node: %s.", node_name)

# if node_name is "", it is the root node, which is not in normal_node_map
if not node_name:
continue
sub_count = graph.normal_node_map.get(node_name).subnode_count
if len(cur_node["nodes"]) < sub_count:

# if the children_count of WatchNodeTree is less than the responding subnode_count in the graph,
# its watch_status must be partial_watch
if cur_tree.get_children_count() < sub_count:
continue
is_all_chosen = True
for sub_node in cur_node["nodes"]:
if sub_node["status"] != SELECTED_STATUS:
for _, sub_tree in cur_tree.get_children():
if sub_tree.watch_status != WatchNodeTree.TOTAL_WATCH:
is_all_chosen = False
break

if is_all_chosen:
cur_node["status"] = SELECTED_STATUS
merged_watch_nodes.append(cur_node)
else:
cur_node["status"] = HALF_SELECTED_STATUS
logger.debug("merged_watch_nodes: %s", merged_watch_nodes)
cur_tree.watch_status = WatchNodeTree.TOTAL_WATCH


def _merge_nodes(leaf_nodes, graph):
"""Merge nodes in one graph."""
watch_node_tree = WatchNodeTree()
for node in leaf_nodes:
watch_node_tree.add_node(node.name, node.type, node.full_name)
_update_watch_status(watch_node_tree, graph)
out_nodes = []
for node_info in merged_watch_nodes:
full_name = graph.get_full_name_by_node_name(node_info["name"])
node_basic_info = NodeBasicInfo(name=node_info["name"], full_name=full_name, type=node_info["type"])
out_nodes.append(node_basic_info)
_convert_tree_to_node_list(watch_node_tree, out_nodes)
logger.debug("out_nodes: %s", out_nodes)
return out_nodes


def _add_graph_name(nodes, graph_stream):
"""add graph_name in node.name"""
"""Add graph_name in node.name."""
if len(graph_stream.graph) > 1:
return nodes
graph_name = graph_stream.graph_names[0]
@@ -407,7 +414,7 @@ def _add_graph_name(nodes, graph_stream):


def _sigmoid(value):
"""calculate the sigmoid of value"""
"""Calculate the sigmoid of value."""
return 1.0 / (1.0 + math.exp(value))




+ 9
- 0
mindinsight/debugger/stream_cache/watchpoint.py View File

@@ -84,6 +84,11 @@ class WatchNodeTree:
"""The property of watch status about current node."""
return self._watch_status

@watch_status.setter
def watch_status(self, value):
"""Set the node watch_status."""
self._watch_status = value

def update_metadata(self, node_type, full_name, watch_status):
"""Update the metadata for watched node."""
self._full_name = full_name
@@ -107,6 +112,10 @@ class WatchNodeTree:
for name_scope, sub_watch_node in self._children.items():
yield name_scope, sub_watch_node

def get_children_count(self):
"""Get the count of children nodes."""
return len(self._children)

def add_node(self, node_name, node_type, full_name=''):
"""
Add watch node to watch node tree.


Loading…
Cancel
Save