Browse Source

!1011 fix the issue that too many watchpoint_hits are shown on UI

From: @jiang-shuqiang
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
752afe0699
7 changed files with 171 additions and 34 deletions
  1. +17
    -0
      mindinsight/backend/debugger/debugger_api.py
  2. +1
    -0
      mindinsight/debugger/debugger_grpc_server.py
  3. +28
    -0
      mindinsight/debugger/debugger_server.py
  4. +116
    -30
      mindinsight/debugger/stream_handler/watchpoint_handler.py
  5. +5
    -4
      mindinsight/debugger/stream_operator/tensor_detail_info.py
  6. +2
    -0
      tests/ut/debugger/configurations.py
  7. +2
    -0
      tests/ut/debugger/expected_results/debugger_server/retrieve_tensor_history.json

+ 17
- 0
mindinsight/backend/debugger/debugger_api.py View File

@@ -333,6 +333,23 @@ def retrieve_tensor_hits():
return reply


@BLUEPRINT.route("/debugger/search-watchpoint-hits", methods=["POST"])
def search_watchpoint_hits():
    """
    Search watchpoint hits according to the posted group condition.

    The `group_condition` entry of the request body is handed to the backend
    server's `search_watchpoint_hits` through `_wrap_reply`.

    Returns:
        str, the required data.

    Examples:
        >>> POST http://xxxx/v1/mindinsight/debugger/search-watchpoint-hits
    """
    group_condition = _read_post_request(request).get('group_condition')
    return _wrap_reply(BACKEND_SERVER.search_watchpoint_hits, group_condition)


# Module-level debugger backend; route handlers such as search_watchpoint_hits call its methods.
BACKEND_SERVER = _initialize_debugger_server()




+ 1
- 0
mindinsight/debugger/debugger_grpc_server.py View File

@@ -191,6 +191,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
for watchpoint_hit in watchpoint_hits:
watchpoint_hit_stream.put(watchpoint_hit)
watchpoint_hits_info = watchpoint_hit_stream.get()
watchpoint_hits_info.update({'receive_watchpoint_hits': True})
self._cache_store.put_data(watchpoint_hits_info)
log.debug("Send the watchpoint hits to DataQueue.\nSend the reply.")



+ 28
- 0
mindinsight/debugger/debugger_server.py View File

@@ -504,6 +504,34 @@ class DebuggerServer:

return reply

def search_watchpoint_hits(self, group_condition):
    """
    Search the watchpoint hits that match the group condition.

    Args:
        group_condition (dict): Filter condition.

            - limit (int): The limit of each page.
            - offset (int): The offset of current page.
            - node_name (str): The retrieved node name.
            - graph_name (str): The retrieved graph name.

    Returns:
        dict, the result of the metadata stream's `get()` while the backend is
        pending, otherwise the grouped watchpoint hits with an extra
        `outdated` entry.

    Raises:
        DebuggerParamTypeError: If `group_condition` is not a dict.
    """
    if not isinstance(group_condition, dict):
        log.error("Group condition for watchpoint-hits request should be a dict")
        raise DebuggerParamTypeError("Group condition for watchpoint-hits request should be a dict")

    # While the backend is pending, answer with the metadata only.
    metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
    backend_pending = metadata_stream.state == ServerStatus.PENDING.value
    if backend_pending:
        log.info("The backend is in pending status.")
        return metadata_stream.get()

    hit_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
    reply = hit_stream.group_by(group_condition)
    watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
    reply['outdated'] = watchpoint_stream.is_recheckable()
    return reply

def create_watchpoint(self, params):
"""
Create watchpoint.


+ 116
- 30
mindinsight/debugger/stream_handler/watchpoint_handler.py View File

@@ -333,12 +333,13 @@ class WatchpointHitHandler(StreamHandlerBase):

def __init__(self):
# dict of <ui node_name, dict of <slot, WatchpointHit>>,
self._hits = {}
self._ordered_hits = []
self._multi_graph_hits = {}

@property
def empty(self):
"""Whether the watchpoint hit is empty."""
return not self._hits
return not self._multi_graph_hits

def put(self, value):
"""
@@ -351,6 +352,7 @@ class WatchpointHitHandler(StreamHandlerBase):
- watchpoint (Watchpoint): The Watchpoint that a node hit.
- node_name (str): The UI node name.
- graph_name (str): The graph name.
- error_code (int): The code of errors.
"""
watchpoint_hit = WatchpointHit(
tensor_proto=value.get('tensor_proto'),
@@ -361,12 +363,12 @@ class WatchpointHitHandler(StreamHandlerBase):
if 'error_code' in value.keys():
watchpoint_hit.error_code = value.get('error_code')
# get all hit watchpoints according to node name and tensor slot
watchpoint_hits = self._get_watchpoints_by_tensor_name(watchpoint_hit.node_name,
watchpoint_hits = self._get_watchpoints_by_tensor_name(watchpoint_hit.graph_name, watchpoint_hit.node_name,
watchpoint_hit.slot)
if watchpoint_hit not in watchpoint_hits:
watchpoint_hits.append(watchpoint_hit)

def _get_watchpoints_by_tensor_name(self, node_name, slot):
def _get_watchpoints_by_tensor_name(self, graph_name, node_name, slot):
"""
Get hit tensors according to ui node name and slot.

@@ -377,14 +379,19 @@ class WatchpointHitHandler(StreamHandlerBase):
Returns:
list, list of watchpoints.
"""
hit_node = self._hits.get(node_name)
if hit_node is None:
index = self._multi_graph_hits.get((graph_name, node_name))
if index is None:
hit_node = {}
self._hits[node_name] = hit_node
self._ordered_hits.append(hit_node)
index = len(self._ordered_hits) - 1
self._multi_graph_hits[(graph_name, node_name)] = index

hit_node = self._ordered_hits[index]
hit_tensors = hit_node.get(slot)
if hit_tensors is None:
hit_tensors = []
hit_node[slot] = hit_tensors

return hit_tensors

def get(self, filter_condition=None):
@@ -398,34 +405,108 @@ class WatchpointHitHandler(StreamHandlerBase):
Returns:
dict, the watchpoint hit list.
"""
reply = None
if filter_condition is None:
log.debug("Get all watchpoint hit list.")
reply = self.get_watchpoint_hits()
else:
log.debug("Get the watchpoint for node: <%s>.", filter_condition)
reply = self._hits.get(filter_condition)
index = self._multi_graph_hits.get(("", filter_condition))
if index is not None:
reply = self._ordered_hits[index]
return reply

def group_by(self, group_condition):
    """
    Dispatch a watchpoint-hit query according to the group condition.

    Args:
        group_condition (dict): The group conditions.

            - limit (int): The limit number of watchpoint hits each page.
            - offset (int): The page offset.
            - node_name (str): The node name.
            - graph_name (str): The graph name.

    Returns:
        dict, the watchpoint hit list.
    """
    # Without a node name the request is answered by page offset;
    # otherwise the page is looked up by name.
    if group_condition.get('node_name') is None:
        return self._get_by_offset(group_condition)
    return self._get_by_name(group_condition)

def _get_by_offset(self, group_condition):
"""Return the list of watchpoint hits on the offset page."""
limit = group_condition.get('limit')
offset = group_condition.get('offset')
if not isinstance(limit, int) or not isinstance(offset, int):
log.error("Param limit or offset is not a integer")
raise DebuggerParamValueError("Param limit or offset is not a integer")
watch_point_hits = []

total = len(self._ordered_hits)

if limit * offset >= total and offset != 0:
log.error("Param offset out of bounds")
raise DebuggerParamValueError("Param offset out of bounds")

if total == 0:
return {}

for watchpoint_hits in self._ordered_hits[(limit * offset): (limit * (offset + 1))]:
self._get_tensors(watchpoint_hits, watch_point_hits)

return {
'watch_point_hits': watch_point_hits,
'offset': offset,
'total': total
}

def _get_by_name(self, group_condition):
"""Return the list of watchpoint hits by the group condition."""
limit = group_condition.get('limit')
if not isinstance(limit, int) or limit == 0:
log.error("Param limit is 0 or not a integer")
raise DebuggerParamValueError("Param limit is 0 or not a integer")

index = self._multi_graph_hits.get((group_condition.get('graph_name'), group_condition.get('node_name')))
if index is not None:
group_condition['offset'] = index//limit
return self._get_by_offset(group_condition)

return {}

def get_watchpoint_hits(self):
"""Return the list of watchpoint hits."""
watch_point_hits = []
for node_name, watchpoint_hits in self._hits.items():
tensors = []
graph_name = None
for slot, tensor_hits in watchpoint_hits.items():
if graph_name is None:
graph_name = tensor_hits[0].graph_name
tensor_info = self._get_tensor_hit_info(slot, tensor_hits)
tensors.append(tensor_info)
watch_point_hits.append({
'node_name': node_name,
'tensors': tensors,
'graph_name': graph_name
})
for watchpoint_hits in self._ordered_hits:
self._get_tensors(watchpoint_hits, watch_point_hits)

return {'watch_point_hits': watch_point_hits}

def _get_tensors(self, watchpoint_hits, watch_point_hits):
"""Get the tensors info for the watchpoint_hits."""
tensors = []
graph_name = None
node_name = None
for slot, tensor_hits in watchpoint_hits.items():
if graph_name is None:
graph_name = tensor_hits[0].graph_name
if node_name is None:
node_name = tensor_hits[0].node_name
tensor_info = self._get_tensor_hit_info(slot, tensor_hits)
tensors.append(tensor_info)

watch_point_hits.append({
'node_name': node_name,
'tensors': tensors,
'graph_name': graph_name
})

@staticmethod
def _get_tensor_hit_info(slot, tensor_hits):
"""
@@ -457,19 +538,23 @@ class WatchpointHitHandler(StreamHandlerBase):
}
return res

def _is_tensor_hit(self, tensor_name):
def _is_tensor_hit(self, tensor_name, graph_name):
"""
Check if the tensor is recorded in the hit cache.

Args:
tensor_name (str): The name of ui tensor name.
graph_name (str): The name of ui graph name

Returns:
bool, if the tensor is hit.
"""
node_name, slot = tensor_name.rsplit(':', 1)
watchpoint_hits = self._hits.get(node_name, {}).get(slot)
return bool(watchpoint_hits)
index = self._multi_graph_hits.get((graph_name, node_name))
if index is not None:
watchpoint_hits = self._ordered_hits[index].get(slot)
return bool(watchpoint_hits)
return False

def update_tensor_history(self, tensor_history):
"""
@@ -478,16 +563,17 @@ class WatchpointHitHandler(StreamHandlerBase):
Args:
tensor_history (dict): The tensor history.
"""
if not self._hits:
if not self._multi_graph_hits:
return

# add hit tensor names to `tensor_names`
for tensor_info in tensor_history.get('tensor_history'):
tensor_name = tensor_info['name']
hit_flag = self._is_tensor_hit(tensor_name)
graph_name = tensor_info['graph_name']
hit_flag = self._is_tensor_hit(tensor_name, graph_name)
tensor_info['is_hit'] = hit_flag

def get_tensor_hit_infos(self, tensor_name):
def get_tensor_hit_infos(self, tensor_name, graph_name):
"""
Get all hit information of a tensor.

@@ -498,9 +584,9 @@ class WatchpointHitHandler(StreamHandlerBase):
dict, tensor hit info.
"""
tensor_hit_info = {}
if self._is_tensor_hit(tensor_name):
if self._is_tensor_hit(tensor_name, graph_name):
node_name, slot = tensor_name.rsplit(':', 1)
tensor_hits = self._get_watchpoints_by_tensor_name(node_name, slot)
tensor_hits = self._get_watchpoints_by_tensor_name(graph_name, node_name, slot)
tensor_hit_info = self._get_tensor_hit_info(slot, tensor_hits)
return tensor_hit_info

@@ -644,7 +730,7 @@ def _get_error_list(error_code):
"""
Get error list.
Args:
error_code (int): the code of errors.
error_code (int): The code of errors.

Returns:
list, the error list.


+ 5
- 4
mindinsight/debugger/stream_operator/tensor_detail_info.py View File

@@ -77,22 +77,23 @@ class TensorDetailInfo:
for node in nodes:
node['graph_name'] = graph_name
for slot_info in node.get('slots', []):
self._add_watchpoint_hit_info(slot_info, node)
self._add_watchpoint_hit_info(slot_info, node, graph_name)
self._add_tensor_info(slot_info, node, missing_tensors)
# query missing tensor values from client
self._ask_for_missing_tensor_value(missing_tensors, tensor_name, graph_name)
return graph

def _add_watchpoint_hit_info(self, slot_info, node):
def _add_watchpoint_hit_info(self, slot_info, node, graph_name):
"""
Add watchpoint hit info for the tensor.

Args:
slot_info (dict): Slot object.
node (dict): Node object.
graph_name (str): Graph name.
"""
tensor_name = ':'.join([node.get('name'), slot_info.get('slot')])
slot_info.update(self._hit_stream.get_tensor_hit_infos(tensor_name))
slot_info.update(self._hit_stream.get_tensor_hit_infos(tensor_name, graph_name))

def _add_tensor_info(self, slot_info, node, missing_tensors):
"""
@@ -141,6 +142,6 @@ class TensorDetailInfo:
# validate tensor_name
self.validate_tensor_name(tensor_name=tensor_name, graph_name=graph_name)
# get watchpoint info that the tensor hit
tensor_hit_info = self._hit_stream.get_tensor_hit_infos(tensor_name)
tensor_hit_info = self._hit_stream.get_tensor_hit_infos(tensor_name, graph_name)
watch_points = tensor_hit_info.get('watch_points', [])
return watch_points

+ 2
- 0
tests/ut/debugger/configurations.py View File

@@ -99,6 +99,7 @@ def mock_tensor_history():
"tensor_history": [
{"name": "Default/TransData-op99:0",
"full_name": "Default/TransData-op99:0",
"graph_name": "kernel_graph_0",
"node_type": "TransData",
"type": "output",
"step": 0,
@@ -108,6 +109,7 @@ def mock_tensor_history():
"value": "click to view"},
{"name": "Default/args0:0",
"full_name": "Default/args0:0",
"graph_name": "kernel_graph_0",
"node_type": "Parameter",
"type": "input",
"step": 0,


+ 2
- 0
tests/ut/debugger/expected_results/debugger_server/retrieve_tensor_history.json View File

@@ -3,6 +3,7 @@
{
"name": "Default/TransData-op99:0",
"full_name": "Default/TransData-op99:0",
"graph_name": "kernel_graph_0",
"node_type": "TransData",
"type": "output",
"step": 0,
@@ -17,6 +18,7 @@
{
"name": "Default/args0:0",
"full_name": "Default/args0:0",
"graph_name": "kernel_graph_0",
"node_type": "Parameter",
"type": "input",
"step": 0,


Loading…
Cancel
Save