Browse Source

!705 fix the bug of watchpoint hit flag

Merge pull request !705 from yelihua/my-merged-debug
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
413264e672
3 changed files with 20 additions and 8 deletions
  1. +11
    -5
      mindinsight/debugger/debugger_server.py
  2. +7
    -1
      mindinsight/debugger/stream_cache/watchpoint.py
  3. +2
    -2
      mindinsight/debugger/stream_handler/watchpoint_handler.py

+ 11
- 5
mindinsight/debugger/debugger_server.py View File

@@ -177,6 +177,12 @@ class DebuggerServer:
log.error("Invalid param <mode>. <mode> should be in ['all', 'node', 'watchpoint', " log.error("Invalid param <mode>. <mode> should be in ['all', 'node', 'watchpoint', "
"'watchpoint_hit', 'tensor'], but got %s.", mode_mapping) "'watchpoint_hit', 'tensor'], but got %s.", mode_mapping)
raise DebuggerParamTypeError("Invalid mode.") raise DebuggerParamTypeError("Invalid mode.")
# validate backend status
metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
if metadata_stream.state == ServerStatus.PENDING.value:
log.info("The backend is in pending status.")
return metadata_stream.get()

filter_condition = {} if filter_condition is None else filter_condition filter_condition = {} if filter_condition is None else filter_condition
reply = mode_mapping[mode](filter_condition) reply = mode_mapping[mode](filter_condition)


@@ -262,12 +268,12 @@ class DebuggerServer:
dict, the tensor history and metadata. dict, the tensor history and metadata.
""" """
log.info("Retrieve tensor history for node: %s.", node_name) log.info("Retrieve tensor history for node: %s.", node_name)
metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
if metadata_stream.state == ServerStatus.PENDING.value:
log.info("The backend is in pending status.")
return metadata_stream.get()
self._validate_leaf_name(node_name) self._validate_leaf_name(node_name)
try:
res = self._get_tensor_history(node_name)
except MindInsightException:
log.warning("Failed to get tensor history for %s.", node_name)
res = {}
res = self._get_tensor_history(node_name)
return res return res


def _validate_leaf_name(self, node_name): def _validate_leaf_name(self, node_name):


+ 7
- 1
mindinsight/debugger/stream_cache/watchpoint.py View File

@@ -285,10 +285,16 @@ class WatchpointHit:


@property @property
def tensor_full_name(self): def tensor_full_name(self):
"""The property of tensor_name."""
"""The property of tensor full name."""
tensor_name = ':'.join([self._full_name, self._slot]) tensor_name = ':'.join([self._full_name, self._slot])
return tensor_name return tensor_name


@property
def tensor_name(self):
"""The property of tensor ui name."""
tensor_name = ':'.join([self._node_name, self._slot])
return tensor_name

@property @property
def watchpoint(self): def watchpoint(self):
"""The property of watchpoint.""" """The property of watchpoint."""


+ 2
- 2
mindinsight/debugger/stream_handler/watchpoint_handler.py View File

@@ -287,7 +287,7 @@ class WatchpointHitHandler(StreamHandlerBase):
return False return False


for watchpoint_hit in watchpoint_hits: for watchpoint_hit in watchpoint_hits:
if tensor_name == watchpoint_hit.tensor_full_name:
if tensor_name == watchpoint_hit.tensor_name:
return True return True


return False return False
@@ -304,7 +304,7 @@ class WatchpointHitHandler(StreamHandlerBase):


# add hit tensor names to `tensor_names` # add hit tensor names to `tensor_names`
for tensor_info in tensor_history.get('tensor_history'): for tensor_info in tensor_history.get('tensor_history'):
tensor_name = tensor_info['full_name']
tensor_name = tensor_info['name']
hit_flag = self._is_tensor_hit(tensor_name) hit_flag = self._is_tensor_hit(tensor_name)
tensor_info['is_hit'] = hit_flag tensor_info['is_hit'] = hit_flag




Loading…
Cancel
Save