diff --git a/mindinsight/backend/debugger/debugger_api.py b/mindinsight/backend/debugger/debugger_api.py index b5f26aac..fe6c7e46 100644 --- a/mindinsight/backend/debugger/debugger_api.py +++ b/mindinsight/backend/debugger/debugger_api.py @@ -333,6 +333,23 @@ def retrieve_tensor_hits(): return reply +@BLUEPRINT.route("/debugger/search-watchpoint-hits", methods=["POST"]) +def search_watchpoint_hits(): + """ + Search watchpoint hits by group condition. + + Returns: + str, the required data. + + Examples: + >>> POST http://xxxx/v1/mindinsight/debugger/search-watchpoint-hits + """ + body = _read_post_request(request) + group_condition = body.get('group_condition') + reply = _wrap_reply(BACKEND_SERVER.search_watchpoint_hits, group_condition) + return reply + + BACKEND_SERVER = _initialize_debugger_server() diff --git a/mindinsight/debugger/debugger_grpc_server.py b/mindinsight/debugger/debugger_grpc_server.py index 3b803e53..169432a5 100644 --- a/mindinsight/debugger/debugger_grpc_server.py +++ b/mindinsight/debugger/debugger_grpc_server.py @@ -191,6 +191,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): for watchpoint_hit in watchpoint_hits: watchpoint_hit_stream.put(watchpoint_hit) watchpoint_hits_info = watchpoint_hit_stream.get() + watchpoint_hits_info.update({'receive_watchpoint_hits': True}) self._cache_store.put_data(watchpoint_hits_info) log.debug("Send the watchpoint hits to DataQueue.\nSend the reply.") diff --git a/mindinsight/debugger/debugger_server.py b/mindinsight/debugger/debugger_server.py index a79ed77a..1badb368 100644 --- a/mindinsight/debugger/debugger_server.py +++ b/mindinsight/debugger/debugger_server.py @@ -504,6 +504,34 @@ class DebuggerServer: return reply + def search_watchpoint_hits(self, group_condition): + """ + Retrieve watchpoint hit. + + Args: + group_condition (dict): Filter condition. + + - limit (int): The limit of each page. + - offset (int): The offset of current page. + - node_name (str): The retrieved node name. + - graph_name (str): The retrieved graph name. + + Returns: + dict, watch point list or relative graph. + """ + if not isinstance(group_condition, dict): + log.error("Group condition for watchpoint-hits request should be a dict") + raise DebuggerParamTypeError("Group condition for watchpoint-hits request should be a dict") + + metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA) + if metadata_stream.state == ServerStatus.PENDING.value: + log.info("The backend is in pending status.") + return metadata_stream.get() + + reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).group_by(group_condition) + reply['outdated'] = self.cache_store.get_stream_handler(Streams.WATCHPOINT).is_recheckable() + return reply + def create_watchpoint(self, params): """ Create watchpoint. diff --git a/mindinsight/debugger/stream_handler/watchpoint_handler.py b/mindinsight/debugger/stream_handler/watchpoint_handler.py index 9ff6cf4f..8a81ea47 100644 --- a/mindinsight/debugger/stream_handler/watchpoint_handler.py +++ b/mindinsight/debugger/stream_handler/watchpoint_handler.py @@ -333,12 +333,13 @@ class WatchpointHitHandler(StreamHandlerBase): def __init__(self): # dict of >, - self._hits = {} + self._ordered_hits = [] + self._multi_graph_hits = {} @property def empty(self): """Whether the watchpoint hit is empty.""" - return not self._hits + return not self._multi_graph_hits def put(self, value): """ @@ -351,6 +352,7 @@ class WatchpointHitHandler(StreamHandlerBase): - watchpoint (Watchpoint): The Watchpoint that a node hit. - node_name (str): The UI node name. - graph_name (str): The graph name. + - error_code (int): The code of errors. """ watchpoint_hit = WatchpointHit( tensor_proto=value.get('tensor_proto'), @@ -361,12 +363,12 @@ class WatchpointHitHandler(StreamHandlerBase): if 'error_code' in value.keys(): watchpoint_hit.error_code = value.get('error_code') # get all hit watchpoints according to node name ans tensor slot - watchpoint_hits = self._get_watchpoints_by_tensor_name(watchpoint_hit.node_name, + watchpoint_hits = self._get_watchpoints_by_tensor_name(watchpoint_hit.graph_name, watchpoint_hit.node_name, watchpoint_hit.slot) if watchpoint_hit not in watchpoint_hits: watchpoint_hits.append(watchpoint_hit) - def _get_watchpoints_by_tensor_name(self, node_name, slot): + def _get_watchpoints_by_tensor_name(self, graph_name, node_name, slot): """ Get hit tensors according to ui node name and slot. @@ -377,14 +379,19 @@ class WatchpointHitHandler(StreamHandlerBase): Returns: list, list of watchpoints. """ - hit_node = self._hits.get(node_name) - if hit_node is None: + index = self._multi_graph_hits.get((graph_name, node_name)) + if index is None: hit_node = {} - self._hits[node_name] = hit_node + self._ordered_hits.append(hit_node) + index = len(self._ordered_hits) - 1 + self._multi_graph_hits[(graph_name, node_name)] = index + + hit_node = self._ordered_hits[index] hit_tensors = hit_node.get(slot) if hit_tensors is None: hit_tensors = [] hit_node[slot] = hit_tensors + return hit_tensors def get(self, filter_condition=None): @@ -398,34 +405,108 @@ class WatchpointHitHandler(StreamHandlerBase): Returns: dict, the watchpoint hit list. """ + reply = None if filter_condition is None: log.debug("Get all watchpoint hit list.") reply = self.get_watchpoint_hits() else: log.debug("Get the watchpoint for node: <%s>.", filter_condition) - reply = self._hits.get(filter_condition) + index = self._multi_graph_hits.get(("", filter_condition)) + if index is not None: + reply = self._ordered_hits[index] + return reply + + def group_by(self, group_condition): + """ + Return the watchpoint hits by group condition. + + Args: + group_condition (dict): The group conditions. + - limit (int): The limit number of watchpoint hits each page. + - offset (int): The page offset. + - node_name (str): The node name. + - graph_name (str): The graph name. + + Returns: + dict, the watchpoint hit list. + """ + node_name = group_condition.get('node_name') + # get all watchpoint hit list + if node_name is None: + reply = self._get_by_offset(group_condition) + else: + reply = self._get_by_name(group_condition) return reply + def _get_by_offset(self, group_condition): + """Return the list of watchpoint hits on the offset page.""" + limit = group_condition.get('limit') + offset = group_condition.get('offset') + if not isinstance(limit, int) or not isinstance(offset, int): + log.error("Param limit or offset is not a integer") + raise DebuggerParamValueError("Param limit or offset is not a integer") + watch_point_hits = [] + + total = len(self._ordered_hits) + + if limit * offset >= total and offset != 0: + log.error("Param offset out of bounds") + raise DebuggerParamValueError("Param offset out of bounds") + + if total == 0: + return {} + + for watchpoint_hits in self._ordered_hits[(limit * offset): (limit * (offset + 1))]: + self._get_tensors(watchpoint_hits, watch_point_hits) + + return { + 'watch_point_hits': watch_point_hits, + 'offset': offset, + 'total': total + } + + def _get_by_name(self, group_condition): + """Return the list of watchpoint hits by the group condition.""" + limit = group_condition.get('limit') + if not isinstance(limit, int) or limit == 0: + log.error("Param limit is 0 or not a integer") + raise DebuggerParamValueError("Param limit is 0 or not a integer") + + index = self._multi_graph_hits.get((group_condition.get('graph_name'), group_condition.get('node_name'))) + if index is not None: + group_condition['offset'] = index//limit + return self._get_by_offset(group_condition) + + return {} + def get_watchpoint_hits(self): """Return the list of watchpoint hits.""" watch_point_hits = [] - for node_name, watchpoint_hits in self._hits.items(): - tensors = [] - graph_name = None - for slot, tensor_hits in watchpoint_hits.items(): - if graph_name is None: - graph_name = tensor_hits[0].graph_name - tensor_info = self._get_tensor_hit_info(slot, tensor_hits) - tensors.append(tensor_info) - watch_point_hits.append({ - 'node_name': node_name, - 'tensors': tensors, - 'graph_name': graph_name - }) + for watchpoint_hits in self._ordered_hits: + self._get_tensors(watchpoint_hits, watch_point_hits) return {'watch_point_hits': watch_point_hits} + def _get_tensors(self, watchpoint_hits, watch_point_hits): + """Get the tensors info for the watchpoint_hits.""" + tensors = [] + graph_name = None + node_name = None + for slot, tensor_hits in watchpoint_hits.items(): + if graph_name is None: + graph_name = tensor_hits[0].graph_name + if node_name is None: + node_name = tensor_hits[0].node_name + tensor_info = self._get_tensor_hit_info(slot, tensor_hits) + tensors.append(tensor_info) + + watch_point_hits.append({ + 'node_name': node_name, + 'tensors': tensors, + 'graph_name': graph_name + }) + @staticmethod def _get_tensor_hit_info(slot, tensor_hits): """ @@ -457,19 +538,23 @@ class WatchpointHitHandler(StreamHandlerBase): } return res - def _is_tensor_hit(self, tensor_name): + def _is_tensor_hit(self, tensor_name, graph_name): """ Check if the tensor is record in hit cache. Args: tensor_name (str): The name of ui tensor name. + graph_name (str): The name of ui graph name Returns: bool, if the tensor is hit. """ node_name, slot = tensor_name.rsplit(':', 1) - watchpoint_hits = self._hits.get(node_name, {}).get(slot) - return bool(watchpoint_hits) + index = self._multi_graph_hits.get((graph_name, node_name)) + if index is not None: + watchpoint_hits = self._ordered_hits[index].get(slot) + return bool(watchpoint_hits) + return False def update_tensor_history(self, tensor_history): """ @@ -478,16 +563,17 @@ class WatchpointHitHandler(StreamHandlerBase): Args: tensor_history (dict): The tensor history. """ - if not self._hits: + if not self._multi_graph_hits: return # add hit tensor names to `tensor_names` for tensor_info in tensor_history.get('tensor_history'): tensor_name = tensor_info['name'] - hit_flag = self._is_tensor_hit(tensor_name) + graph_name = tensor_info['graph_name'] + hit_flag = self._is_tensor_hit(tensor_name, graph_name) tensor_info['is_hit'] = hit_flag - def get_tensor_hit_infos(self, tensor_name): + def get_tensor_hit_infos(self, tensor_name, graph_name): """ Get all hit information of a tensor. @@ -498,9 +584,9 @@ class WatchpointHitHandler(StreamHandlerBase): dict, tensor hit info. """ tensor_hit_info = {} - if self._is_tensor_hit(tensor_name): + if self._is_tensor_hit(tensor_name, graph_name): node_name, slot = tensor_name.rsplit(':', 1) - tensor_hits = self._get_watchpoints_by_tensor_name(node_name, slot) + tensor_hits = self._get_watchpoints_by_tensor_name(graph_name, node_name, slot) tensor_hit_info = self._get_tensor_hit_info(slot, tensor_hits) return tensor_hit_info @@ -644,7 +730,7 @@ def _get_error_list(error_code): """ Get error list. Args: - error_code (int): the code of errors. + error_code (int): The code of errors. Returns: list, the error list. diff --git a/mindinsight/debugger/stream_operator/tensor_detail_info.py b/mindinsight/debugger/stream_operator/tensor_detail_info.py index a785478a..72a15cd4 100644 --- a/mindinsight/debugger/stream_operator/tensor_detail_info.py +++ b/mindinsight/debugger/stream_operator/tensor_detail_info.py @@ -77,22 +77,23 @@ class TensorDetailInfo: for node in nodes: node['graph_name'] = graph_name for slot_info in node.get('slots', []): - self._add_watchpoint_hit_info(slot_info, node) + self._add_watchpoint_hit_info(slot_info, node, graph_name) self._add_tensor_info(slot_info, node, missing_tensors) # query missing tensor values from client self._ask_for_missing_tensor_value(missing_tensors, tensor_name, graph_name) return graph - def _add_watchpoint_hit_info(self, slot_info, node): + def _add_watchpoint_hit_info(self, slot_info, node, graph_name): """ Add watchpoint hit info for the tensor. Args: slot_info (dict): Slot object. node (dict): Node object. + graph_name (str): Graph name. """ tensor_name = ':'.join([node.get('name'), slot_info.get('slot')]) - slot_info.update(self._hit_stream.get_tensor_hit_infos(tensor_name)) + slot_info.update(self._hit_stream.get_tensor_hit_infos(tensor_name, graph_name)) def _add_tensor_info(self, slot_info, node, missing_tensors): """ @@ -141,6 +142,6 @@ class TensorDetailInfo: # validate tensor_name self.validate_tensor_name(tensor_name=tensor_name, graph_name=graph_name) # get watchpoint info that the tensor hit - tensor_hit_info = self._hit_stream.get_tensor_hit_infos(tensor_name) + tensor_hit_info = self._hit_stream.get_tensor_hit_infos(tensor_name, graph_name) watch_points = tensor_hit_info.get('watch_points', []) return watch_points diff --git a/tests/ut/debugger/configurations.py b/tests/ut/debugger/configurations.py index d72bb31e..3a0597b7 100644 --- a/tests/ut/debugger/configurations.py +++ b/tests/ut/debugger/configurations.py @@ -99,6 +99,7 @@ def mock_tensor_history(): "tensor_history": [ {"name": "Default/TransData-op99:0", "full_name": "Default/TransData-op99:0", + "graph_name": "kernel_graph_0", "node_type": "TransData", "type": "output", "step": 0, @@ -108,6 +109,7 @@ def mock_tensor_history(): "value": "click to view"}, {"name": "Default/args0:0", "full_name": "Default/args0:0", + "graph_name": "kernel_graph_0", "node_type": "Parameter", "type": "input", "step": 0, diff --git a/tests/ut/debugger/expected_results/debugger_server/retrieve_tensor_history.json b/tests/ut/debugger/expected_results/debugger_server/retrieve_tensor_history.json index 19390509..8a0ad390 100644 --- a/tests/ut/debugger/expected_results/debugger_server/retrieve_tensor_history.json +++ b/tests/ut/debugger/expected_results/debugger_server/retrieve_tensor_history.json @@ -3,6 +3,7 @@ { "name": "Default/TransData-op99:0", "full_name": "Default/TransData-op99:0", + "graph_name": "kernel_graph_0", "node_type": "TransData", "type": "output", "step": 0, @@ -17,6 +18,7 @@ { "name": "Default/args0:0", "full_name": "Default/args0:0", + "graph_name": "kernel_graph_0", "node_type": "Parameter", "type": "input", "step": 0,