Browse Source

change the behavior of recheck

tags/v1.1.0
yelihua 5 years ago
parent
commit
8e9c0e2acf
5 changed files with 24 additions and 76 deletions
  1. +0
    -4
      mindinsight/debugger/debugger_grpc_server.py
  2. +12
    -6
      mindinsight/debugger/stream_cache/watchpoint.py
  3. +8
    -62
      mindinsight/debugger/stream_handler/watchpoint_handler.py
  4. +2
    -2
      mindinsight/debugger/stream_operator/watchpoint_operator.py
  5. +2
    -2
      tests/st/func/debugger/test_restful_api.py

+ 0
- 4
mindinsight/debugger/debugger_grpc_server.py View File

@@ -99,7 +99,6 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
def _pre_process(self, request):
"""Pre-process before dealing with command."""
metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT)
is_new_step = metadata_stream.step < request.cur_step
is_new_node = metadata_stream.full_name != request.cur_node
# clean cache data at the beginning of a new step or when the node has changed.
@@ -108,15 +107,12 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
if is_new_step:
self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
self._cache_store.get_stream_handler(Streams.TENSOR).clean_tensors(request.cur_step)
watchpoint_stream.clean_temp_cached_names()
# receive graph at the beginning of the training
if self._status == ServerStatus.RECEIVE_GRAPH:
self._send_graph_flag(metadata_stream)
# receive new metadata
if is_new_step or is_new_node:
self._update_metadata(metadata_stream, request)
# save the full name of the node whose tensor MindSpore has stored.
watchpoint_stream.add_temp_cached_name(request.cur_node)
self._send_received_tensor_tag()
self._send_watchpoint_hit_flag()



+ 12
- 6
mindinsight/debugger/stream_cache/watchpoint.py View File

@@ -13,13 +13,15 @@
# limitations under the License.
# ============================================================================
"""Define the watchpoint stream."""
from mindinsight.debugger.conditionmgr.common.utils import NodeBasicInfo

import copy

from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError
from mindinsight.debugger.common.log import LOGGER as log
from mindinsight.debugger.common.utils import is_scope_type
from mindinsight.debugger.proto.debug_grpc_pb2 import SetCMD, WatchCondition
from mindinsight.debugger.conditionmgr.common.utils import NodeBasicInfo
from mindinsight.debugger.conditionmgr.condition import ConditionIdEnum
from mindinsight.debugger.proto.debug_grpc_pb2 import SetCMD, WatchCondition

WATCHPOINT_CONDITION_MAPPING = {
ConditionIdEnum.NAN.value: WatchCondition.Condition.nan,
@@ -109,7 +111,7 @@ class WatchNodeTree:
return self._children.get(sub_name)

def get_children(self):
    """
    Get all children.

    Yields:
        tuple, a `(name_scope, sub_watch_node)` pair for each entry of `self._children`.
    """
    # Delegate straight to the dict view instead of re-yielding each pair by hand.
    yield from self._children.items()

@@ -198,13 +200,17 @@ class Watchpoint:
"""The property of watch condition."""
return self._condition

def copy_nodes_from(self, other_watchpoint, deep_copy=False):
    """
    Copy nodes from other watchpoint.

    Args:
        other_watchpoint (Watchpoint): Other watchpoint.
        deep_copy (bool): Whether using deepcopy. Default: False.
    """
    source_nodes = other_watchpoint.nodes
    # Without a deep copy, both watchpoints reference the very same node object.
    self._watch_node = copy.deepcopy(source_nodes) if deep_copy else source_nodes

def add_nodes(self, nodes):
"""Add node into watchpoint."""


+ 8
- 62
mindinsight/debugger/stream_handler/watchpoint_handler.py View File

@@ -33,19 +33,13 @@ class WatchpointHandler(StreamHandlerBase):
self._created_watchpoints = []
# list of SetCMD of watchpoints to be deleted
self._deleted_watchpoints = []
# dict of <id, SetCMD> of watchpoint to be updated
# dict of <id, Watchpoint> of watchpoints to be updated
self._updated_watchpoints = {}
# the collection of watched node full names, which have been sent to MindSpore
self._all_watched_node_full_names = set()
# the collection of new watched node full names, which have not been sent to MindSpore
self._new_watched_node_full_names = set()
# record the temp stored nodes in MS, which could be set as watch node for recheck on GPU
# should be clean at the beginning of each step
self._temp_cached_node_full_names = set()
self._latest_id = 0
self._cache_set_cmd = {}
# whether the watchpoint list has been changed since last step
self.outdated = False
self._outdated = False

def put(self, value):
"""
@@ -61,18 +55,9 @@ class WatchpointHandler(StreamHandlerBase):
self._latest_id = new_id
log.debug("Put watchpoint %d into cache.", new_id)

def clean_temp_cached_names(self):
    """Empty the set of temporarily cached node full names in place."""
    cached_names = self._temp_cached_node_full_names
    cached_names.clear()

def add_temp_cached_name(self, node_full_name):
    """
    Record a node full name in the temporary cache.

    Args:
        node_full_name (str): The full name of the node. Falsy values are ignored.
    """
    if not node_full_name:
        return
    self._temp_cached_node_full_names.add(node_full_name)

def sync_set_cmd(self, set_cmds):
"""Clean temp watchpoints."""
self._new_watched_node_full_names = set()
self._outdated = False
self._created_watchpoints = []
self._deleted_watchpoints = []
self._updated_watchpoints = {}
@@ -126,20 +111,14 @@ class WatchpointHandler(StreamHandlerBase):
list[SetCMD], updated watchpoint to be sent to MindSpore.
"""
res = []
new_watched_nodes = set()
self._all_watched_node_full_names.clear()
for _, watchpoint in self._updated_watchpoints.items():
# construct set command with leaf nodes
watch_nodes = watchpoint.get_watch_nodes()
leaf_watch_nodes = self._expand_to_leaf_nodes(graph_stream, watch_nodes)
res.append(watchpoint.get_pending_cmd(leaf_watch_nodes))
# update all watched node names
watch_node_names = [watch_node.full_name for watch_node in [*watch_nodes, *leaf_watch_nodes]]
new_watched_nodes.update(watch_node_names)
res.extend(self._deleted_watchpoints)
for _, set_cmd in self._cache_set_cmd.items():
res.append(set_cmd)
self._all_watched_node_full_names = new_watched_nodes
return res

@staticmethod
@@ -168,23 +147,14 @@ class WatchpointHandler(StreamHandlerBase):
leaf_watch_nodes.append(node)
return leaf_watch_nodes

def is_recheckable(self, backend=None):
def is_recheckable(self):
"""
Check if current status is able to recheck.

Args:
backend (str): The backend info. 'Ascend' or 'GPU'. Default: None.

Returns:
bool, whether recheck is enabled.
"""
enable_recheck = self.outdated
if backend == 'GPU' and enable_recheck:
# on GPU, disable to recheck if there are new watched node of which the tensor
# has not been stored on MindSpore
diff_set = self._new_watched_node_full_names - self._all_watched_node_full_names
enable_recheck = not diff_set or diff_set.issubset(self._temp_cached_node_full_names)
return enable_recheck
return self._outdated

def set_watch_nodes(self, graph, graph_stream, watch_point_id, graph_name=None):
"""
@@ -274,12 +244,11 @@ class WatchpointHandler(StreamHandlerBase):
watchpoint = Watchpoint(new_id, watch_condition)
if watch_nodes:
watchpoint.add_nodes(watch_nodes)
self._add_watch_node_in_cache(watch_nodes)
elif watch_point_id:
self.validate_watchpoint_id(watch_point_id)
watchpoint.copy_nodes_from(self._watchpoints.get(watch_point_id))
self.put(watchpoint)
self.outdated = True
self._outdated = True
return new_id

def update_watchpoint(self, watch_point_id, watch_nodes, watched=False):
@@ -296,12 +265,10 @@ class WatchpointHandler(StreamHandlerBase):
watchpoint = self._watchpoints.get(watch_point_id)
if watched:
watchpoint.add_nodes(watch_nodes)
self._add_watch_node_in_cache(watch_nodes)
else:
watchpoint.remove_nodes(watch_nodes)
self._remove_watch_node_from_cache(watch_nodes)
self._updated_watchpoints[watch_point_id] = watchpoint
self.outdated = True
self._outdated = True
log.debug("Update watchpoint %d in cache.", watch_point_id)

def delete_watchpoint(self, watch_point_id=None):
@@ -319,7 +286,7 @@ class WatchpointHandler(StreamHandlerBase):
watch_point_ids = [watch_point_id]
for single_id in watch_point_ids:
self._delete_single_watchpoint(single_id)
self.outdated = True
self._outdated = True

def _delete_single_watchpoint(self, watch_point_id):
"""
@@ -350,27 +317,6 @@ class WatchpointHandler(StreamHandlerBase):
log.error("Invalid watchpoint id: %d.", watch_point_id)
raise DebuggerParamValueError("Invalid watchpoint id: {}".format(watch_point_id))

def _add_watch_node_in_cache(self, watch_nodes):
    """
    Add the full names of the given watch nodes to the new-watched-node cache.

    Args:
        watch_nodes (list[NodeBasicInfo]): The list of node basic info.
    """
    # Collect every name first so the cache is only touched once all names are read.
    full_names = {watch_node.full_name for watch_node in watch_nodes}
    self._new_watched_node_full_names.update(full_names)

def _remove_watch_node_from_cache(self, watch_nodes):
    """
    Remove the full names of the given watch nodes from the new-watched-node cache.

    Names that are not present in the cache are ignored.

    Args:
        watch_nodes (list[NodeBasicInfo]): The list of node basic info.
    """
    # difference_update drops only the names that are present, so no
    # membership test is needed before removal.
    self._new_watched_node_full_names.difference_update(
        node.full_name for node in watch_nodes)


class WatchpointHitHandler(StreamHandlerBase):
"""Watchpoint hit handler."""


+ 2
- 2
mindinsight/debugger/stream_operator/watchpoint_operator.py View File

@@ -87,7 +87,7 @@ class WatchpointOperator:
self._condition_mgr, watch_condition, watch_nodes, params.get('watch_point_id'))
log.info("Create watchpoint %d", watch_point_id)

metadata_stream.enable_recheck = watchpoint_stream.is_recheckable(metadata_stream.backend)
metadata_stream.enable_recheck = watchpoint_stream.is_recheckable()
res = metadata_stream.get(['state', 'enable_recheck'])
res['id'] = watch_point_id
return res
@@ -140,7 +140,7 @@ class WatchpointOperator:
search_pattern=params.get('search_pattern'),
graph_name=params.get('graph_name'))
watchpoint_stream.update_watchpoint(watch_point_id, watch_nodes, params.get('mode'))
metadata_stream.enable_recheck = watchpoint_stream.is_recheckable(metadata_stream.backend)
metadata_stream.enable_recheck = watchpoint_stream.is_recheckable()
log.info("Update watchpoint with id: %d", watch_point_id)
return metadata_stream.get(['state', 'enable_recheck'])



+ 2
- 2
tests/st/func/debugger/test_restful_api.py View File

@@ -434,7 +434,7 @@ class TestGPUDebugger:
@pytest.mark.parametrize("url, body_data, enable_recheck", [
('create_watchpoint',
{'condition': {'id': 'inf', 'params': []},
'watch_nodes': ['Default']}, False),
'watch_nodes': ['Default']}, True),
('create_watchpoint',
{'condition': {'id': 'inf', 'params': []},
'watch_nodes': ['Default/TransData-op99']}, True),
@@ -443,7 +443,7 @@ class TestGPUDebugger:
'mode': 0}, True),
('update_watchpoint',
{'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum'],
'mode': 1}, False),
'mode': 1}, True),
('update_watchpoint',
[{'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum'],
'mode': 1},


Loading…
Cancel
Save