|
|
|
@@ -33,19 +33,13 @@ class WatchpointHandler(StreamHandlerBase): |
|
|
|
self._created_watchpoints = [] |
|
|
|
# list of SetCMD of watchpoints to be deleted |
|
|
|
self._deleted_watchpoints = [] |
|
|
|
# dict of <id, SetCMD> of watchpoint to be updated |
|
|
|
# dict of <id, Watchpoint> of watchpoints to be updated |
|
|
|
self._updated_watchpoints = {} |
|
|
|
# the collection of watched node full names, which have been sent to MindSpore |
|
|
|
self._all_watched_node_full_names = set() |
|
|
|
# the collection of new watched node full names, which have not been sent to MindSpore |
|
|
|
self._new_watched_node_full_names = set() |
|
|
|
# record the temp stored nodes in MS, which could be set as watch node for recheck on GPU |
|
|
|
# should be clean at the beginning of each step |
|
|
|
self._temp_cached_node_full_names = set() |
|
|
|
self._latest_id = 0 |
|
|
|
self._cache_set_cmd = {} |
|
|
|
# whether the watchpoint list has been changed since last step |
|
|
|
self.outdated = False |
|
|
|
self._outdated = False |
|
|
|
|
|
|
|
def put(self, value): |
|
|
|
""" |
|
|
|
@@ -61,18 +55,9 @@ class WatchpointHandler(StreamHandlerBase): |
|
|
|
self._latest_id = new_id |
|
|
|
log.debug("Put watchpoint %d into cache.", new_id) |
|
|
|
|
|
|
|
def clean_temp_cached_names(self): |
|
|
|
"""Clean temp cached node.""" |
|
|
|
self._temp_cached_node_full_names.clear() |
|
|
|
|
|
|
|
def add_temp_cached_name(self, node_full_name): |
|
|
|
"""Add temp stored node in cache.""" |
|
|
|
if node_full_name: |
|
|
|
self._temp_cached_node_full_names.add(node_full_name) |
|
|
|
|
|
|
|
def sync_set_cmd(self, set_cmds): |
|
|
|
"""Clean temp watchpoints.""" |
|
|
|
self._new_watched_node_full_names = set() |
|
|
|
self._outdated = False |
|
|
|
self._created_watchpoints = [] |
|
|
|
self._deleted_watchpoints = [] |
|
|
|
self._updated_watchpoints = {} |
|
|
|
@@ -126,20 +111,14 @@ class WatchpointHandler(StreamHandlerBase): |
|
|
|
list[SetCMD], updated watchpoint to be sent to MindSpore. |
|
|
|
""" |
|
|
|
res = [] |
|
|
|
new_watched_nodes = set() |
|
|
|
self._all_watched_node_full_names.clear() |
|
|
|
for _, watchpoint in self._updated_watchpoints.items(): |
|
|
|
# construct set command with leaf nodes |
|
|
|
watch_nodes = watchpoint.get_watch_nodes() |
|
|
|
leaf_watch_nodes = self._expand_to_leaf_nodes(graph_stream, watch_nodes) |
|
|
|
res.append(watchpoint.get_pending_cmd(leaf_watch_nodes)) |
|
|
|
# update all watched node names |
|
|
|
watch_node_names = [watch_node.full_name for watch_node in [*watch_nodes, *leaf_watch_nodes]] |
|
|
|
new_watched_nodes.update(watch_node_names) |
|
|
|
res.extend(self._deleted_watchpoints) |
|
|
|
for _, set_cmd in self._cache_set_cmd.items(): |
|
|
|
res.append(set_cmd) |
|
|
|
self._all_watched_node_full_names = new_watched_nodes |
|
|
|
return res |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
@@ -168,23 +147,14 @@ class WatchpointHandler(StreamHandlerBase): |
|
|
|
leaf_watch_nodes.append(node) |
|
|
|
return leaf_watch_nodes |
|
|
|
|
|
|
|
def is_recheckable(self, backend=None): |
|
|
|
def is_recheckable(self): |
|
|
|
""" |
|
|
|
Check if current status is able to recheck. |
|
|
|
|
|
|
|
Args: |
|
|
|
backend (str): The backend info. 'Ascend' or 'GPU'. Default: None. |
|
|
|
|
|
|
|
Returns: |
|
|
|
bool, if enable to recheck. |
|
|
|
""" |
|
|
|
enable_recheck = self.outdated |
|
|
|
if backend == 'GPU' and enable_recheck: |
|
|
|
# on GPU, disable to recheck if there are new watched node of which the tensor |
|
|
|
# has not been stored on MindSpore |
|
|
|
diff_set = self._new_watched_node_full_names - self._all_watched_node_full_names |
|
|
|
enable_recheck = not diff_set or diff_set.issubset(self._temp_cached_node_full_names) |
|
|
|
return enable_recheck |
|
|
|
return self._outdated |
|
|
|
|
|
|
|
def set_watch_nodes(self, graph, graph_stream, watch_point_id, graph_name=None): |
|
|
|
""" |
|
|
|
@@ -274,12 +244,11 @@ class WatchpointHandler(StreamHandlerBase): |
|
|
|
watchpoint = Watchpoint(new_id, watch_condition) |
|
|
|
if watch_nodes: |
|
|
|
watchpoint.add_nodes(watch_nodes) |
|
|
|
self._add_watch_node_in_cache(watch_nodes) |
|
|
|
elif watch_point_id: |
|
|
|
self.validate_watchpoint_id(watch_point_id) |
|
|
|
watchpoint.copy_nodes_from(self._watchpoints.get(watch_point_id)) |
|
|
|
self.put(watchpoint) |
|
|
|
self.outdated = True |
|
|
|
self._outdated = True |
|
|
|
return new_id |
|
|
|
|
|
|
|
def update_watchpoint(self, watch_point_id, watch_nodes, watched=False): |
|
|
|
@@ -296,12 +265,10 @@ class WatchpointHandler(StreamHandlerBase): |
|
|
|
watchpoint = self._watchpoints.get(watch_point_id) |
|
|
|
if watched: |
|
|
|
watchpoint.add_nodes(watch_nodes) |
|
|
|
self._add_watch_node_in_cache(watch_nodes) |
|
|
|
else: |
|
|
|
watchpoint.remove_nodes(watch_nodes) |
|
|
|
self._remove_watch_node_from_cache(watch_nodes) |
|
|
|
self._updated_watchpoints[watch_point_id] = watchpoint |
|
|
|
self.outdated = True |
|
|
|
self._outdated = True |
|
|
|
log.debug("Update watchpoint %d in cache.", watch_point_id) |
|
|
|
|
|
|
|
def delete_watchpoint(self, watch_point_id=None): |
|
|
|
@@ -319,7 +286,7 @@ class WatchpointHandler(StreamHandlerBase): |
|
|
|
watch_point_ids = [watch_point_id] |
|
|
|
for single_id in watch_point_ids: |
|
|
|
self._delete_single_watchpoint(single_id) |
|
|
|
self.outdated = True |
|
|
|
self._outdated = True |
|
|
|
|
|
|
|
def _delete_single_watchpoint(self, watch_point_id): |
|
|
|
""" |
|
|
|
@@ -350,27 +317,6 @@ class WatchpointHandler(StreamHandlerBase): |
|
|
|
log.error("Invalid watchpoint id: %d.", watch_point_id) |
|
|
|
raise DebuggerParamValueError("Invalid watchpoint id: {}".format(watch_point_id)) |
|
|
|
|
|
|
|
def _add_watch_node_in_cache(self, watch_nodes): |
|
|
|
""" |
|
|
|
Add watch nodes in cache. |
|
|
|
|
|
|
|
Args: |
|
|
|
watch_nodes (list[NodeBasicInfo]): The list of node basic info. |
|
|
|
""" |
|
|
|
node_full_names = [node.full_name for node in watch_nodes] |
|
|
|
self._new_watched_node_full_names.update(node_full_names) |
|
|
|
|
|
|
|
def _remove_watch_node_from_cache(self, watch_nodes): |
|
|
|
""" |
|
|
|
Remove watch nodes from cache. |
|
|
|
|
|
|
|
Args: |
|
|
|
watch_nodes (list[NodeBasicInfo]): The list of node basic info. |
|
|
|
""" |
|
|
|
for node in watch_nodes: |
|
|
|
if node.full_name in self._new_watched_node_full_names: |
|
|
|
self._new_watched_node_full_names.remove(node.full_name) |
|
|
|
|
|
|
|
|
|
|
|
class WatchpointHitHandler(StreamHandlerBase): |
|
|
|
"""Watchpoint hit handler.""" |
|
|
|
|