Browse Source

fix the bugs abourt showing invalid watch node

tags/v1.1.0
yelihua 5 years ago
parent
commit
8e30ec23f0
8 changed files with 74 additions and 13 deletions
  1. +5
    -0
      mindinsight/debugger/common/utils.py
  2. +1
    -0
      mindinsight/debugger/debugger_server.py
  3. +4
    -1
      mindinsight/debugger/stream_cache/watchpoint.py
  4. +9
    -4
      mindinsight/debugger/stream_handler/watchpoint_handler.py
  5. +6
    -5
      mindinsight/debugger/stream_operator/watchpoint_operator.py
  6. +1
    -1
      tests/st/func/debugger/expect_results/restful_results/retrieve_empty_watchpoint_hit_list
  7. +47
    -1
      tests/st/func/debugger/expect_results/restful_results/retrieve_watchpoint_hit.json
  8. +1
    -1
      tests/ut/debugger/test_debugger_server.py

+ 5
- 0
mindinsight/debugger/common/utils.py View File

@@ -145,3 +145,8 @@ def create_view_event_from_tensor_basic_info(tensors_info):
def is_scope_type(node_type):
"""Judge whether the type is scope type."""
return node_type.endswith('scope')


def is_cst_type(node_type):
"""Judge whether the type is const type."""
return node_type == NodeTypeEnum.CONST.value

+ 1
- 0
mindinsight/debugger/debugger_server.py View File

@@ -489,6 +489,7 @@ class DebuggerServer:
# get all watchpoint hit list
if node_name is None:
reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).get()
reply['outdated'] = self.cache_store.get_stream_handler(Streams.WATCHPOINT).is_recheckable()
return reply
graph_name = self.cache_store.get_stream_handler(Streams.GRAPH).validate_graph_name(
filter_condition.get('graph_name'))


+ 4
- 1
mindinsight/debugger/stream_cache/watchpoint.py View File

@@ -18,7 +18,7 @@ import copy

from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError
from mindinsight.debugger.common.log import LOGGER as log
from mindinsight.debugger.common.utils import is_scope_type
from mindinsight.debugger.common.utils import is_scope_type, is_cst_type
from mindinsight.debugger.conditionmgr.common.utils import NodeBasicInfo
from mindinsight.debugger.conditionmgr.condition import ConditionIdEnum
from mindinsight.debugger.proto.debug_grpc_pb2 import SetCMD, WatchCondition
@@ -56,6 +56,7 @@ WATCHPOINT_CONDITION_MAPPING = {

class WatchNodeTree:
"""The WatchNode Node Structure."""
INVALID = -1 # the scope node and the nodes below are invalid
NOT_WATCH = 0 # the scope node and the nodes below are not watched
PARTIAL_WATCH = 1 # at least one node under the scope node is not watched
TOTAL_WATCH = 2 # the scope node and the nodes below are all watched
@@ -234,6 +235,8 @@ class Watchpoint:

def get_node_status(self, node_name, node_type, full_name):
"""Judge if the node is in watch nodes."""
if is_cst_type(node_type):
return WatchNodeTree.INVALID
scope_names = node_name.split('/')
cur_node = self._watch_node
status = 1


+ 9
- 4
mindinsight/debugger/stream_handler/watchpoint_handler.py View File

@@ -192,7 +192,8 @@ class WatchpointHandler(StreamHandlerBase):
int, the number of all watched nodes.
"""
all_watched_num = 0
# the state of current node.
valid_node_num = len(nodes)
# initialize the state of current node.
state = WatchNodeTree.NOT_WATCH
for node in nodes:
node_name = node.get('name')
@@ -207,10 +208,14 @@ class WatchpointHandler(StreamHandlerBase):
if flag == WatchNodeTree.NOT_WATCH:
continue
state = WatchNodeTree.PARTIAL_WATCH
if flag == WatchNodeTree.TOTAL_WATCH:
if flag == WatchNodeTree.INVALID:
valid_node_num -= 1
elif flag == WatchNodeTree.TOTAL_WATCH:
all_watched_num += 1

if all_watched_num == len(nodes):
# update the watch status of current node
if not valid_node_num:
state = WatchNodeTree.INVALID
elif all_watched_num == valid_node_num:
state = WatchNodeTree.TOTAL_WATCH
return state



+ 6
- 5
mindinsight/debugger/stream_operator/watchpoint_operator.py View File

@@ -20,7 +20,7 @@ from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValue
DebuggerDeleteWatchPointError
from mindinsight.debugger.common.log import LOGGER as log
from mindinsight.debugger.common.utils import ServerStatus, \
Streams
Streams, is_cst_type
from mindinsight.debugger.conditionmgr.condition import ConditionIdEnum, TargetTypeEnum
from mindinsight.debugger.conditionmgr.recommender import get_basic_node_info
from mindinsight.debugger.stream_handler.watchpoint_handler import validate_watch_condition
@@ -211,7 +211,7 @@ class WatchpointOperator:
cur_node = tmp_queue.get()
for node in cur_node.get('nodes'):
node_name = node.get('name')
if not target_node_name.startswith(node_name):
if not target_node_name.startswith(node_name) or is_cst_type(node.get('type')):
continue
if target_node_name == node_name:
self._add_leaf_node_collection(node, names)
@@ -263,14 +263,14 @@ class WatchpointOperator:

def _get_node_basic_infos(self, node_names, graph_name=None):
"""
Get node info according to node names.
Get watch node info according to node names.

Args:
node_names (Union[set[str], list[str]]): A collection of node names.
graph_name (str): The relative graph_name of the watched node. Default: None.

Returns:
list[NodeBasicInfo], a list of basic node infos.
list[NodeBasicInfo], a list of basic watch nodes info.
"""
if not node_names:
return []
@@ -278,6 +278,7 @@ class WatchpointOperator:
node_infos = []
for node_name in node_names:
node_info = graph_stream.get_node_basic_info(node_name, graph_name)
node_infos.append(node_info)
if not is_cst_type(node_info.type):
node_infos.append(node_info)

return node_infos

+ 1
- 1
tests/st/func/debugger/expect_results/restful_results/retrieve_empty_watchpoint_hit_list View File

@@ -1 +1 @@
{"watch_point_hits": []}
{"watch_point_hits": [], "outdated": false}

+ 47
- 1
tests/st/func/debugger/expect_results/restful_results/retrieve_watchpoint_hit.json View File

@@ -1 +1,47 @@
{"watch_point_hits": [{"node_name": "Default/TransData-op99", "tensors": [{"slot": "0", "summarized_error_code": 0, "watch_points": [{"id": 1, "watch_condition": {"id": "inf", "params": [], "abbr": "INF"}, "error_code": 0}]}], "graph_name": "graph_0"}, {"node_name": "Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op25", "tensors": [{"slot": "0", "summarized_error_code": 0, "watch_points": [{"id": 1, "watch_condition": {"id": "inf", "params": [], "abbr": "INF"}, "error_code": 0}]}], "graph_name": "graph_0"}]}
{
"watch_point_hits": [
{
"node_name": "Default/TransData-op99",
"tensors": [
{
"slot": "0",
"summarized_error_code": 0,
"watch_points": [
{
"id": 1,
"watch_condition": {
"id": "inf",
"params": [],
"abbr": "INF"
},
"error_code": 0
}
]
}
],
"graph_name": "graph_0"
},
{
"node_name": "Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op25",
"tensors": [
{
"slot": "0",
"summarized_error_code": 0,
"watch_points": [
{
"id": 1,
"watch_condition": {
"id": "inf",
"params": [],
"abbr": "INF"
},
"error_code": 0
}
]
}
],
"graph_name": "graph_0"
}
],
"outdated": false
}

+ 1
- 1
tests/ut/debugger/test_debugger_server.py View File

@@ -193,7 +193,7 @@ class TestDebuggerServer:
self._server.create_watchpoint({'watch_condition': {'id': 'inf'}})

@mock.patch.object(MetadataHandler, 'state', 'waiting')
@mock.patch.object(GraphHandler, 'get_node_basic_info', return_value=[MagicMock()])
@mock.patch.object(GraphHandler, 'get_node_basic_info', return_value=MagicMock())
@mock.patch.object(GraphHandler, 'get_node_type', return_value='aggregation_scope')
@mock.patch.object(WatchpointHandler, 'create_watchpoint')
def test_create_watchpoint(self, *args):


Loading…
Cancel
Save