diff --git a/mindinsight/debugger/common/utils.py b/mindinsight/debugger/common/utils.py index 9dceaf58..b4607f9c 100644 --- a/mindinsight/debugger/common/utils.py +++ b/mindinsight/debugger/common/utils.py @@ -145,3 +145,8 @@ def create_view_event_from_tensor_basic_info(tensors_info): def is_scope_type(node_type): """Judge whether the type is scope type.""" return node_type.endswith('scope') + + +def is_cst_type(node_type): + """Judge whether the type is const type.""" + return node_type == NodeTypeEnum.CONST.value diff --git a/mindinsight/debugger/debugger_server.py b/mindinsight/debugger/debugger_server.py index 5438167a..ad83af19 100644 --- a/mindinsight/debugger/debugger_server.py +++ b/mindinsight/debugger/debugger_server.py @@ -489,6 +489,7 @@ class DebuggerServer: # get all watchpoint hit list if node_name is None: reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).get() + reply['outdated'] = self.cache_store.get_stream_handler(Streams.WATCHPOINT).is_recheckable() return reply graph_name = self.cache_store.get_stream_handler(Streams.GRAPH).validate_graph_name( filter_condition.get('graph_name')) diff --git a/mindinsight/debugger/stream_cache/watchpoint.py b/mindinsight/debugger/stream_cache/watchpoint.py index 4d0e57c6..032eb9af 100644 --- a/mindinsight/debugger/stream_cache/watchpoint.py +++ b/mindinsight/debugger/stream_cache/watchpoint.py @@ -18,7 +18,7 @@ import copy from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError from mindinsight.debugger.common.log import LOGGER as log -from mindinsight.debugger.common.utils import is_scope_type +from mindinsight.debugger.common.utils import is_scope_type, is_cst_type from mindinsight.debugger.conditionmgr.common.utils import NodeBasicInfo from mindinsight.debugger.conditionmgr.condition import ConditionIdEnum from mindinsight.debugger.proto.debug_grpc_pb2 import SetCMD, WatchCondition @@ -56,6 +56,7 @@ WATCHPOINT_CONDITION_MAPPING = { class WatchNodeTree: """The WatchNode Node Structure.""" + INVALID = -1 # the scope node and the nodes below are invalid NOT_WATCH = 0 # the scope node and the nodes below are not watched PARTIAL_WATCH = 1 # at least one node under the scope node is not watched TOTAL_WATCH = 2 # the scope node and the nodes below are all watched @@ -234,6 +235,8 @@ class Watchpoint: def get_node_status(self, node_name, node_type, full_name): """Judge if the node is in watch nodes.""" + if is_cst_type(node_type): + return WatchNodeTree.INVALID scope_names = node_name.split('/') cur_node = self._watch_node status = 1 diff --git a/mindinsight/debugger/stream_handler/watchpoint_handler.py b/mindinsight/debugger/stream_handler/watchpoint_handler.py index cb5e319d..bbaffb21 100644 --- a/mindinsight/debugger/stream_handler/watchpoint_handler.py +++ b/mindinsight/debugger/stream_handler/watchpoint_handler.py @@ -192,7 +192,8 @@ class WatchpointHandler(StreamHandlerBase): int, the number of all watched nodes. """ all_watched_num = 0 - # the state of current node. + valid_node_num = len(nodes) + # initialize the state of current node. state = WatchNodeTree.NOT_WATCH for node in nodes: node_name = node.get('name') @@ -207,10 +208,14 @@ class WatchpointHandler(StreamHandlerBase): if flag == WatchNodeTree.NOT_WATCH: continue state = WatchNodeTree.PARTIAL_WATCH - if flag == WatchNodeTree.TOTAL_WATCH: + if flag == WatchNodeTree.INVALID: + valid_node_num -= 1 + elif flag == WatchNodeTree.TOTAL_WATCH: all_watched_num += 1 - - if all_watched_num == len(nodes): + # update the watch status of current node + if not valid_node_num: + state = WatchNodeTree.INVALID + elif all_watched_num == valid_node_num: state = WatchNodeTree.TOTAL_WATCH return state diff --git a/mindinsight/debugger/stream_operator/watchpoint_operator.py b/mindinsight/debugger/stream_operator/watchpoint_operator.py index fbb3104f..5d81a0b4 100644 --- a/mindinsight/debugger/stream_operator/watchpoint_operator.py +++ b/mindinsight/debugger/stream_operator/watchpoint_operator.py @@ -20,7 +20,7 @@ from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValue DebuggerDeleteWatchPointError from mindinsight.debugger.common.log import LOGGER as log from mindinsight.debugger.common.utils import ServerStatus, \ - Streams + Streams, is_cst_type from mindinsight.debugger.conditionmgr.condition import ConditionIdEnum, TargetTypeEnum from mindinsight.debugger.conditionmgr.recommender import get_basic_node_info from mindinsight.debugger.stream_handler.watchpoint_handler import validate_watch_condition @@ -211,7 +211,7 @@ class WatchpointOperator: cur_node = tmp_queue.get() for node in cur_node.get('nodes'): node_name = node.get('name') - if not target_node_name.startswith(node_name): + if not target_node_name.startswith(node_name) or is_cst_type(node.get('type')): continue if target_node_name == node_name: self._add_leaf_node_collection(node, names) @@ -263,14 +263,14 @@ class WatchpointOperator: def _get_node_basic_infos(self, node_names, graph_name=None): """ - Get node info according to node names. + Get watch node info according to node names. Args: node_names (Union[set[str], list[str]]): A collection of node names. graph_name (str): The relative graph_name of the watched node. Default: None. Returns: - list[NodeBasicInfo], a list of basic node infos. + list[NodeBasicInfo], a list of basic watch nodes info. """ if not node_names: return [] @@ -278,6 +278,7 @@ class WatchpointOperator: node_infos = [] for node_name in node_names: node_info = graph_stream.get_node_basic_info(node_name, graph_name) - node_infos.append(node_info) + if not is_cst_type(node_info.type): + node_infos.append(node_info) return node_infos diff --git a/tests/st/func/debugger/expect_results/restful_results/retrieve_empty_watchpoint_hit_list b/tests/st/func/debugger/expect_results/restful_results/retrieve_empty_watchpoint_hit_list index 23e664b2..32e1ef11 100644 --- a/tests/st/func/debugger/expect_results/restful_results/retrieve_empty_watchpoint_hit_list +++ b/tests/st/func/debugger/expect_results/restful_results/retrieve_empty_watchpoint_hit_list @@ -1 +1 @@ -{"watch_point_hits": []} \ No newline at end of file +{"watch_point_hits": [], "outdated": false} \ No newline at end of file diff --git a/tests/st/func/debugger/expect_results/restful_results/retrieve_watchpoint_hit.json b/tests/st/func/debugger/expect_results/restful_results/retrieve_watchpoint_hit.json index 79bc0fb6..508abddf 100644 --- a/tests/st/func/debugger/expect_results/restful_results/retrieve_watchpoint_hit.json +++ b/tests/st/func/debugger/expect_results/restful_results/retrieve_watchpoint_hit.json @@ -1 +1,47 @@ -{"watch_point_hits": [{"node_name": "Default/TransData-op99", "tensors": [{"slot": "0", "summarized_error_code": 0, "watch_points": [{"id": 1, "watch_condition": {"id": "inf", "params": [], "abbr": "INF"}, "error_code": 0}]}], "graph_name": "graph_0"}, {"node_name": "Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op25", "tensors": [{"slot": "0", "summarized_error_code": 0, "watch_points": [{"id": 1, "watch_condition": {"id": "inf", "params": [], "abbr": "INF"}, "error_code": 0}]}], "graph_name": "graph_0"}]} \ No newline at end of file +{ + "watch_point_hits": [ + { + "node_name": "Default/TransData-op99", + "tensors": [ + { + "slot": "0", + "summarized_error_code": 0, + "watch_points": [ + { + "id": 1, + "watch_condition": { + "id": "inf", + "params": [], + "abbr": "INF" + }, + "error_code": 0 + } + ] + } + ], + "graph_name": "graph_0" + }, + { + "node_name": "Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op25", + "tensors": [ + { + "slot": "0", + "summarized_error_code": 0, + "watch_points": [ + { + "id": 1, + "watch_condition": { + "id": "inf", + "params": [], + "abbr": "INF" + }, + "error_code": 0 + } + ] + } + ], + "graph_name": "graph_0" + } + ], + "outdated": false +} \ No newline at end of file diff --git a/tests/ut/debugger/test_debugger_server.py b/tests/ut/debugger/test_debugger_server.py index 6cb3b8ff..210df69f 100644 --- a/tests/ut/debugger/test_debugger_server.py +++ b/tests/ut/debugger/test_debugger_server.py @@ -193,7 +193,7 @@ class TestDebuggerServer: self._server.create_watchpoint({'watch_condition': {'id': 'inf'}}) @mock.patch.object(MetadataHandler, 'state', 'waiting') - @mock.patch.object(GraphHandler, 'get_node_basic_info', return_value=[MagicMock()]) + @mock.patch.object(GraphHandler, 'get_node_basic_info', return_value=MagicMock()) @mock.patch.object(GraphHandler, 'get_node_type', return_value='aggregation_scope') @mock.patch.object(WatchpointHandler, 'create_watchpoint') def test_create_watchpoint(self, *args):