Browse Source

!675 fix the bug about watchpoint id

Merge pull request !675 from yelihua/my-merged-debug
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
6b5f90d3c3
2 changed files with 40 additions and 30 deletions
  1. +37
    -29
      mindinsight/debugger/debugger_server.py
  2. +3
    -1
      mindinsight/debugger/stream_handler/watchpoint_handler.py

+ 37
- 29
mindinsight/debugger/debugger_server.py View File

@@ -45,7 +45,6 @@ class DebuggerServer:
self.grpc_server = DebuggerGrpcServer(self.cache_store) self.grpc_server = DebuggerGrpcServer(self.cache_store)
self.grpc_server_manager = None self.grpc_server_manager = None
self.back_server = None self.back_server = None
self._watch_point_id = 0


def start(self): def start(self):
"""Start server.""" """Start server."""
@@ -95,12 +94,23 @@ class DebuggerServer:


return reply return reply


def search(self, name, watch_point_id):
"""Search for single node in graph."""
def search(self, name, watch_point_id=0):
"""
Search for single node in graph.

Args:
name (str): The name pattern.
watch_point_id (int): The id of watchpoint. Default: 0.

Returns:
dict, the searched nodes.
"""
log.info("receive search request for node:%s, in watchpoint:%d", name, watch_point_id) log.info("receive search request for node:%s, in watchpoint:%d", name, watch_point_id)
watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
watchpoint_stream.validate_watchpoint_id(watch_point_id)
graph = self.cache_store.get_stream_handler(Streams.GRAPH).search_nodes(name) graph = self.cache_store.get_stream_handler(Streams.GRAPH).search_nodes(name)
self.cache_store.get_stream_handler(Streams.WATCHPOINT).set_watch_nodes(
graph, watch_point_id)
# add watched label to graph
watchpoint_stream.set_watch_nodes(graph, watch_point_id)
return graph return graph


def tensor_comparisons(self, name, shape, detail='data', tolerance='0'): def tensor_comparisons(self, name, shape, detail='data', tolerance='0'):
@@ -155,8 +165,6 @@ class DebuggerServer:
""" """
log.info("receive retrieve request for mode:%s\n, filter_condition: %s", mode, log.info("receive retrieve request for mode:%s\n, filter_condition: %s", mode,
filter_condition) filter_condition)
# validate watchpoint_id

mode_mapping = { mode_mapping = {
'all': self._retrieve_all, 'all': self._retrieve_all,
'node': self._retrieve_node, 'node': self._retrieve_node,
@@ -178,10 +186,9 @@ class DebuggerServer:
if filter_condition: if filter_condition:
log.error("No filter condition required for retrieve all request.") log.error("No filter condition required for retrieve all request.")
raise DebuggerParamTypeError("filter_condition should be empty.") raise DebuggerParamTypeError("filter_condition should be empty.")
result = {}
self._watch_point_id = 0
self.cache_store.clean_data() self.cache_store.clean_data()
log.info("Clean data queue cache when retrieve all request.") log.info("Clean data queue cache when retrieve all request.")
result = {}
for stream in [Streams.METADATA, Streams.GRAPH, Streams.WATCHPOINT]: for stream in [Streams.METADATA, Streams.GRAPH, Streams.WATCHPOINT]:
sub_res = self.cache_store.get_stream_handler(stream).get() sub_res = self.cache_store.get_stream_handler(stream).get()
result.update(sub_res) result.update(sub_res)
@@ -200,13 +207,15 @@ class DebuggerServer:
- single_node (bool): If False, return the sub-layer of single node. If True, return - single_node (bool): If False, return the sub-layer of single node. If True, return
the node list from root node to single node. the node list from root node to single node.


- watch_point_id (int): The id of watchpoint.

Returns: Returns:
dict, the node info.
dict, reply with graph.
""" """
log.info("Retrieve node %s.", filter_condition) log.info("Retrieve node %s.", filter_condition)
# validate node name
node_name = filter_condition.get('name') node_name = filter_condition.get('name')
if node_name: if node_name:
# validate node name
self.cache_store.get_stream_handler(Streams.GRAPH).get_node_type(node_name) self.cache_store.get_stream_handler(Streams.GRAPH).get_node_type(node_name)
filter_condition['single_node'] = bool(filter_condition.get('single_node')) filter_condition['single_node'] = bool(filter_condition.get('single_node'))
reply = self._get_nodes_info(filter_condition) reply = self._get_nodes_info(filter_condition)
@@ -224,16 +233,21 @@ class DebuggerServer:
- single_node (bool): If False, return the sub-layer of single node. If True, return - single_node (bool): If False, return the sub-layer of single node. If True, return
the node list from root node to single node. the node list from root node to single node.


- watch_point_id (int): The id of watchpoint.

Returns: Returns:
dict, reply with graph. dict, reply with graph.
""" """
# validate watch_point_id
watch_point_id = filter_condition.get('watch_point_id', 0)
watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
watchpoint_stream.validate_watchpoint_id(watch_point_id)
# get graph # get graph
graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
reply = graph_stream.get(filter_condition) reply = graph_stream.get(filter_condition)
graph = reply.get('graph') graph = reply.get('graph')
# add watched label
self.cache_store.get_stream_handler(Streams.WATCHPOINT).set_watch_nodes(
graph, self._watch_point_id)
# add watched label to graph
watchpoint_stream.set_watch_nodes(graph, watch_point_id)
return reply return reply


def retrieve_tensor_history(self, node_name): def retrieve_tensor_history(self, node_name):
@@ -353,7 +367,7 @@ class DebuggerServer:
Args: Args:
filter_condition (dict): Filter condition. filter_condition (dict): Filter condition.


- watch_point_id (int): The id of watchoint. If not given, return all watchpoints.
- watch_point_id (int): The id of watchpoint. If not given, return all watchpoints.


- name (str): The name of single node. - name (str): The name of single node.


@@ -363,10 +377,7 @@ class DebuggerServer:
Returns: Returns:
dict, watch point list or relative graph. dict, watch point list or relative graph.
""" """
watchpoint_id = filter_condition.get('watch_point_id')
watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
watchpoint_stream.validate_watchpoint_id(watchpoint_id)
self._watch_point_id = watchpoint_id if watchpoint_id else 0
watchpoint_id = filter_condition.get('watch_point_id', 0)
if not watchpoint_id: if not watchpoint_id:
reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT).get() reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT).get()
log.debug("Get condition of watchpoints.") log.debug("Get condition of watchpoints.")
@@ -392,11 +403,11 @@ class DebuggerServer:
dict, watch point list or relative graph. dict, watch point list or relative graph.
""" """
node_name = filter_condition.get('name') node_name = filter_condition.get('name')
# get watchpoint hit list
# get all watchpoint hit list
if node_name is None: if node_name is None:
reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).get() reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).get()
return reply return reply
# get tensor history and graph of the hit node.
self._validate_leaf_name(node_name) self._validate_leaf_name(node_name)
# get tensor history # get tensor history
reply = self._get_tensor_history(node_name) reply = self._get_tensor_history(node_name)
@@ -438,7 +449,6 @@ class DebuggerServer:
watch_nodes = self._get_node_basic_infos(watch_nodes) watch_nodes = self._get_node_basic_infos(watch_nodes)
watch_point_id = self.cache_store.get_stream_handler(Streams.WATCHPOINT).create_watchpoint( watch_point_id = self.cache_store.get_stream_handler(Streams.WATCHPOINT).create_watchpoint(
watch_condition, watch_nodes, watch_point_id) watch_condition, watch_nodes, watch_point_id)
self._watch_point_id = 0
log.info("Create watchpoint %d", watch_point_id) log.info("Create watchpoint %d", watch_point_id)
return {'id': watch_point_id} return {'id': watch_point_id}


@@ -462,7 +472,9 @@ class DebuggerServer:
raise DebuggerUpdateWatchPointError( raise DebuggerUpdateWatchPointError(
"Failed to update watchpoint as the MindSpore is not in waiting state." "Failed to update watchpoint as the MindSpore is not in waiting state."
) )
# validate
# validate parameter
watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
watchpoint_stream.validate_watchpoint_id(watch_point_id)
if not watch_nodes or not watch_point_id: if not watch_nodes or not watch_point_id:
log.error("Invalid parameter for update watchpoint.") log.error("Invalid parameter for update watchpoint.")
raise DebuggerParamValueError("Invalid parameter for update watchpoint.") raise DebuggerParamValueError("Invalid parameter for update watchpoint.")
@@ -472,9 +484,7 @@ class DebuggerServer:
elif mode == 1: elif mode == 1:
watch_nodes = self._get_node_basic_infos(watch_nodes) watch_nodes = self._get_node_basic_infos(watch_nodes)


self.cache_store.get_stream_handler(Streams.WATCHPOINT).update_watchpoint(
watch_point_id, watch_nodes, mode)
self._watch_point_id = watch_point_id
watchpoint_stream.update_watchpoint(watch_point_id, watch_nodes, mode)
log.info("Update watchpoint with id: %d", watch_point_id) log.info("Update watchpoint with id: %d", watch_point_id)
return {} return {}


@@ -510,9 +520,7 @@ class DebuggerServer:
raise DebuggerDeleteWatchPointError( raise DebuggerDeleteWatchPointError(
"Failed to delete watchpoint as the MindSpore is not in waiting state." "Failed to delete watchpoint as the MindSpore is not in waiting state."
) )
self.cache_store.get_stream_handler(Streams.WATCHPOINT).delete_watchpoint(
watch_point_id)
self._watch_point_id = 0
self.cache_store.get_stream_handler(Streams.WATCHPOINT).delete_watchpoint(watch_point_id)
log.info("Delete watchpoint with id: %d", watch_point_id) log.info("Delete watchpoint with id: %d", watch_point_id)
return {} return {}




+ 3
- 1
mindinsight/debugger/stream_handler/watchpoint_handler.py View File

@@ -98,7 +98,6 @@ class WatchpointHandler(StreamHandlerBase):
""" """
if not (watch_point_id and graph): if not (watch_point_id and graph):
return return
self.validate_watchpoint_id(watch_point_id)
log.debug("add watch flags") log.debug("add watch flags")
watchpoint = self._watchpoints.get(watch_point_id) watchpoint = self._watchpoints.get(watch_point_id)
self._set_watch_status_recursively(graph, watchpoint) self._set_watch_status_recursively(graph, watchpoint)
@@ -192,6 +191,9 @@ class WatchpointHandler(StreamHandlerBase):


def validate_watchpoint_id(self, watch_point_id): def validate_watchpoint_id(self, watch_point_id):
"""Validate watchpoint id.""" """Validate watchpoint id."""
if not isinstance(watch_point_id, int):
log.error("Invalid watchpoint id %s. The watch point id should be int.", watch_point_id)
raise DebuggerParamTypeError("Watchpoint id should be int type.")
if watch_point_id and watch_point_id not in self._watchpoints: if watch_point_id and watch_point_id not in self._watchpoints:
log.error("Invalid watchpoint id: %d.", watch_point_id) log.error("Invalid watchpoint id: %d.", watch_point_id)
raise DebuggerParamValueError("Invalid watchpoint id: {}".format(watch_point_id)) raise DebuggerParamValueError("Invalid watchpoint id: {}".format(watch_point_id))


Loading…
Cancel
Save