diff --git a/mindinsight/debugger/debugger_server.py b/mindinsight/debugger/debugger_server.py
index 4c7c01f0..bbc3b06c 100644
--- a/mindinsight/debugger/debugger_server.py
+++ b/mindinsight/debugger/debugger_server.py
@@ -45,7 +45,6 @@ class DebuggerServer:
         self.grpc_server = DebuggerGrpcServer(self.cache_store)
         self.grpc_server_manager = None
         self.back_server = None
-        self._watch_point_id = 0
 
     def start(self):
         """Start server."""
@@ -95,12 +94,23 @@ class DebuggerServer:
 
         return reply
 
-    def search(self, name, watch_point_id):
-        """Search for single node in graph."""
+    def search(self, name, watch_point_id=0):
+        """
+        Search for single node in graph.
+
+        Args:
+            name (str): The name pattern.
+            watch_point_id (int): The id of watchpoint. Default: 0.
+
+        Returns:
+            dict, the searched nodes.
+        """
         log.info("receive search request for node:%s, in watchpoint:%d", name, watch_point_id)
+        watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
+        watchpoint_stream.validate_watchpoint_id(watch_point_id)
         graph = self.cache_store.get_stream_handler(Streams.GRAPH).search_nodes(name)
-        self.cache_store.get_stream_handler(Streams.WATCHPOINT).set_watch_nodes(
-            graph, watch_point_id)
+        # add watched label to graph
+        watchpoint_stream.set_watch_nodes(graph, watch_point_id)
         return graph
 
     def tensor_comparisons(self, name, shape, detail='data', tolerance='0'):
@@ -155,8 +165,6 @@ class DebuggerServer:
         """
         log.info("receive retrieve request for mode:%s\n, filter_condition: %s", mode, filter_condition)
-        # validate watchpoint_id
-
         mode_mapping = {
             'all': self._retrieve_all,
             'node': self._retrieve_node,
@@ -178,10 +186,9 @@ class DebuggerServer:
         if filter_condition:
             log.error("No filter condition required for retrieve all request.")
             raise DebuggerParamTypeError("filter_condition should be empty.")
-        result = {}
-        self._watch_point_id = 0
         self.cache_store.clean_data()
         log.info("Clean data queue cache when retrieve all request.")
+        result = {}
         for stream in [Streams.METADATA, Streams.GRAPH, Streams.WATCHPOINT]:
             sub_res = self.cache_store.get_stream_handler(stream).get()
             result.update(sub_res)
@@ -200,13 +207,15 @@ class DebuggerServer:
                 - single_node (bool): If False, return the sub-layer of single node. If True, return
                     the node list from root node to single node.
 
+                - watch_point_id (int): The id of watchpoint.
+
         Returns:
-            dict, the node info.
+            dict, reply with graph.
         """
         log.info("Retrieve node %s.", filter_condition)
+        # validate node name
         node_name = filter_condition.get('name')
         if node_name:
-            # validate node name
             self.cache_store.get_stream_handler(Streams.GRAPH).get_node_type(node_name)
         filter_condition['single_node'] = bool(filter_condition.get('single_node'))
         reply = self._get_nodes_info(filter_condition)
@@ -224,16 +233,21 @@ class DebuggerServer:
                 - single_node (bool): If False, return the sub-layer of single node. If True, return
                     the node list from root node to single node.
 
+                - watch_point_id (int): The id of watchpoint.
+
         Returns:
             dict, reply with graph.
         """
+        # validate watch_point_id
+        watch_point_id = filter_condition.get('watch_point_id', 0)
+        watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
+        watchpoint_stream.validate_watchpoint_id(watch_point_id)
         # get graph
         graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
         reply = graph_stream.get(filter_condition)
         graph = reply.get('graph')
-        # add watched label
-        self.cache_store.get_stream_handler(Streams.WATCHPOINT).set_watch_nodes(
-            graph, self._watch_point_id)
+        # add watched label to graph
+        watchpoint_stream.set_watch_nodes(graph, watch_point_id)
         return reply
 
     def retrieve_tensor_history(self, node_name):
@@ -353,7 +367,7 @@ class DebuggerServer:
         Args:
             filter_condition (dict): Filter condition.
 
-                - watch_point_id (int): The id of watchoint. If not given, return all watchpoints.
+                - watch_point_id (int): The id of watchpoint. If not given, return all watchpoints.
 
                 - name (str): The name of single node.
 
@@ -363,10 +377,7 @@ class DebuggerServer:
         Returns:
             dict, watch point list or relative graph.
         """
-        watchpoint_id = filter_condition.get('watch_point_id')
-        watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
-        watchpoint_stream.validate_watchpoint_id(watchpoint_id)
-        self._watch_point_id = watchpoint_id if watchpoint_id else 0
+        watchpoint_id = filter_condition.get('watch_point_id', 0)
         if not watchpoint_id:
             reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT).get()
             log.debug("Get condition of watchpoints.")
@@ -392,11 +403,11 @@ class DebuggerServer:
             dict, watch point list or relative graph.
         """
         node_name = filter_condition.get('name')
-        # get watchpoint hit list
+        # get all watchpoint hit list
        if node_name is None:
             reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).get()
             return reply
-
+        # get tensor history and graph of the hit node.
         self._validate_leaf_name(node_name)
         # get tensor history
         reply = self._get_tensor_history(node_name)
@@ -438,7 +449,6 @@ class DebuggerServer:
         watch_nodes = self._get_node_basic_infos(watch_nodes)
         watch_point_id = self.cache_store.get_stream_handler(Streams.WATCHPOINT).create_watchpoint(
             watch_condition, watch_nodes, watch_point_id)
-        self._watch_point_id = 0
         log.info("Create watchpoint %d", watch_point_id)
         return {'id': watch_point_id}
 
@@ -462,7 +472,9 @@ class DebuggerServer:
             raise DebuggerUpdateWatchPointError(
                 "Failed to update watchpoint as the MindSpore is not in waiting state."
             )
-        # validate
+        # validate parameter
+        watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
+        watchpoint_stream.validate_watchpoint_id(watch_point_id)
         if not watch_nodes or not watch_point_id:
             log.error("Invalid parameter for update watchpoint.")
             raise DebuggerParamValueError("Invalid parameter for update watchpoint.")
@@ -472,9 +484,7 @@ class DebuggerServer:
         elif mode == 1:
             watch_nodes = self._get_node_basic_infos(watch_nodes)
 
-        self.cache_store.get_stream_handler(Streams.WATCHPOINT).update_watchpoint(
-            watch_point_id, watch_nodes, mode)
-        self._watch_point_id = watch_point_id
+        watchpoint_stream.update_watchpoint(watch_point_id, watch_nodes, mode)
         log.info("Update watchpoint with id: %d", watch_point_id)
         return {}
 
@@ -510,9 +520,7 @@ class DebuggerServer:
             raise DebuggerDeleteWatchPointError(
                 "Failed to delete watchpoint as the MindSpore is not in waiting state."
             )
-        self.cache_store.get_stream_handler(Streams.WATCHPOINT).delete_watchpoint(
-            watch_point_id)
-        self._watch_point_id = 0
+        self.cache_store.get_stream_handler(Streams.WATCHPOINT).delete_watchpoint(watch_point_id)
         log.info("Delete watchpoint with id: %d", watch_point_id)
         return {}
 
diff --git a/mindinsight/debugger/stream_handler/watchpoint_handler.py b/mindinsight/debugger/stream_handler/watchpoint_handler.py
index 0c91e1be..0665f437 100644
--- a/mindinsight/debugger/stream_handler/watchpoint_handler.py
+++ b/mindinsight/debugger/stream_handler/watchpoint_handler.py
@@ -98,7 +98,6 @@ class WatchpointHandler(StreamHandlerBase):
         """
         if not (watch_point_id and graph):
             return
-        self.validate_watchpoint_id(watch_point_id)
         log.debug("add watch flags")
         watchpoint = self._watchpoints.get(watch_point_id)
         self._set_watch_status_recursively(graph, watchpoint)
@@ -192,6 +191,9 @@ class WatchpointHandler(StreamHandlerBase):
 
     def validate_watchpoint_id(self, watch_point_id):
         """Validate watchpoint id."""
+        if not isinstance(watch_point_id, int):
+            log.error("Invalid watchpoint id %s. The watch point id should be int.", watch_point_id)
+            raise DebuggerParamTypeError("Watchpoint id should be int type.")
         if watch_point_id and watch_point_id not in self._watchpoints:
             log.error("Invalid watchpoint id: %d.", watch_point_id)
             raise DebuggerParamValueError("Invalid watchpoint id: {}".format(watch_point_id))
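
The behavioral core of this patch is that the server no longer caches a watchpoint id on itself (self._watch_point_id); callers pass watch_point_id explicitly and WatchpointHandler.validate_watchpoint_id rejects bad values in one place: type check first, then membership check, with 0 meaning "no watchpoint selected". Below is a minimal standalone sketch of that validation order; the ParamTypeError/ParamValueError classes and the existing_ids argument are illustrative stand-ins for the real DebuggerParamTypeError/DebuggerParamValueError and the handler's internal _watchpoints dict, not part of this patch.

    class ParamTypeError(TypeError):
        """Stand-in for DebuggerParamTypeError."""

    class ParamValueError(ValueError):
        """Stand-in for DebuggerParamValueError."""

    def validate_watchpoint_id(watch_point_id, existing_ids):
        """Mirror the check added to WatchpointHandler: type first, then membership."""
        if not isinstance(watch_point_id, int):
            raise ParamTypeError("Watchpoint id should be int type.")
        if watch_point_id and watch_point_id not in existing_ids:
            raise ParamValueError("Invalid watchpoint id: {}".format(watch_point_id))

    validate_watchpoint_id(0, existing_ids={1, 2})   # ok: 0 (the new default) skips the lookup
    validate_watchpoint_id(2, existing_ids={1, 2})   # ok: known id
    try:
        validate_watchpoint_id(3, existing_ids={1, 2})
    except ParamValueError as err:
        print(err)                                   # Invalid watchpoint id: 3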