Browse Source

fix the updating for watchpoint hit

tags/v1.0.0
yelihua 5 years ago
parent
commit
2a0546186c
3 changed files with 81 additions and 52 deletions
  1. +69
    -49
      mindinsight/debugger/debugger_grpc_server.py
  2. +6
    -2
      mindinsight/debugger/debugger_server.py
  3. +6
    -1
      mindinsight/debugger/stream_handler/watchpoint_handler.py

+ 69
- 49
mindinsight/debugger/debugger_grpc_server.py View File

@@ -71,26 +71,14 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
log.warning("No graph received before WaitCMD.")
reply = get_ack_reply(1)
return reply
self._send_received_tensor_tag()
# send graph if has not been sent before
# send graph if it has not been sent before
self._pre_process(request)
# deal with old command
reply = self._deal_with_old_command()
if reply:
log.info("Reply to WaitCMD with old command: %s", reply)
return reply
# continue multiple steps training
if self._continue_steps:
reply = get_ack_reply()
reply.run_cmd.run_steps = 1
reply.run_cmd.run_level = 'step'
self._continue_steps = self._continue_steps - 1 if self._continue_steps > 0 else -1
self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
log.debug("Send RunCMD. Clean watchpoint hit.")
# wait for command
else:
# wait for next command
if reply is None:
reply = self._wait_for_next_command()
# check the reply
if reply is None:
reply = get_ack_reply(1)
log.warning("Failed to get command event.")
@@ -98,42 +86,48 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
log.info("Reply to WaitCMD: %s", reply)
return reply


def _send_received_tensor_tag(self):
"""Send received_finish_tag."""
node_name = self._received_view_cmd.get('node_name')
if not node_name or self._received_view_cmd.get('wait_for_tensor'):
return
metadata = self._cache_store.get_stream_handler(Streams.METADATA).get()
ret = {'receive_tensor': {'node_name': node_name}}
ret.update(metadata)
self._cache_store.put_data(ret)
self._received_view_cmd.clear()
log.debug("Send receive tensor flag for %s", node_name)

def _pre_process(self, request):
"""Send graph and metadata when WaitCMD first called."""
"""Pre-process before dealing with command."""
metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
is_new_step = metadata_stream.step < request.cur_step
# clean cache data at the beginning of new step
if is_new_step:
self._cache_store.clean_data()
self._cache_store.get_stream_handler(Streams.TENSOR).clean_tensors(request.cur_step)
# receive graph at the beginning of the training
if self._status == ServerStatus.RECEIVE_GRAPH:
self._status = ServerStatus.WAITING
metadata_stream.state = 'waiting'
metadata = metadata_stream.get()
self._cache_store.clean_command()
res = self._cache_store.get_stream_handler(Streams.GRAPH).get()
res.update(metadata)
self._cache_store.put_data(res)
log.debug("Put graph into data queue.")

if metadata_stream.step < request.cur_step or metadata_stream.full_name != request.cur_node:
# clean tensor cache and DataQueue at the beginning of each step
self._send_graph_flag(metadata_stream)
# receive new metadata
if is_new_step or metadata_stream.full_name != request.cur_node:
self._update_metadata(metadata_stream, request)
self._send_received_tensor_tag()
self._send_watchpoint_hit_flag()

def _send_graph_flag(self, metadata_stream):
"""
Send graph and metadata to UI.

Args:
metadata_stream (MetadataHandler): Metadata handler stream.
"""
self._cache_store.clean_command()
# receive graph in the beginning of the training
self._status = ServerStatus.WAITING
metadata_stream.state = 'waiting'
metadata = metadata_stream.get()
res = self._cache_store.get_stream_handler(Streams.GRAPH).get()
res.update(metadata)
self._cache_store.put_data(res)
log.debug("Put graph into data queue.")


def _update_metadata(self, metadata_stream, metadata_proto):
"""Update metadata."""
# reset view round and clean cache data
if metadata_stream.step < metadata_proto.cur_step:
self._cache_store.clean_data()
self._cache_store.get_stream_handler(Streams.TENSOR).clean_tensors(
metadata_proto.cur_step)
"""
Update metadata.

Args:
metadata_stream (MetadataHandler): Metadata handler stream.
metadata_proto (MetadataProto): Metadata proto sent by client.
"""
# put new metadata into cache
metadata_stream.put(metadata_proto)
cur_node = self._cache_store.get_stream_handler(Streams.GRAPH).get_node_name_by_full_name(
@@ -143,12 +137,41 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
self._cache_store.put_data(metadata)
log.debug("Put new metadata into data queue.")


def _send_received_tensor_tag(self):
"""Send received_finish_tag."""
node_name = self._received_view_cmd.get('node_name')
if not node_name or self._received_view_cmd.get('wait_for_tensor'):
return
metadata = self._cache_store.get_stream_handler(Streams.METADATA).get()
ret = {'receive_tensor': {'node_name': node_name}}
ret.update(metadata)
self._cache_store.put_data(ret)
self._received_view_cmd.clear()
log.debug("Send receive tensor flag for %s", node_name)

def _send_watchpoint_hit_flag(self):
"""Send Watchpoint hit flag."""
watchpoint_hit_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
if watchpoint_hit_stream.empty:
return
watchpoint_hits_info = watchpoint_hit_stream.get()
self._cache_store.put_data(watchpoint_hits_info)
log.debug("Send the watchpoint hits to DataQueue.\nSend the reply.")

def _deal_with_old_command(self):
"""Deal with old command."""
event = None
while self._cache_store.has_command(self._pos) and event is None:
event = self._get_next_command()
log.debug("Deal with old %s-th command:\n%s.", self._pos, event)
# continue multiple steps training
if event is None and self._continue_steps:
event = get_ack_reply()
event.run_cmd.run_steps = 1
event.run_cmd.run_level = 'step'
self._continue_steps = self._continue_steps - 1 if self._continue_steps > 0 else -1
self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
log.debug("Send RunCMD. Clean watchpoint hit.")


return event


@@ -295,8 +318,5 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
'node_name': ui_node_name
}
watchpoint_hit_stream.put(watchpoint_hit)
watchpoint_hits_info = watchpoint_hit_stream.get()
self._cache_store.put_data(watchpoint_hits_info)
log.debug("Send the watchpoint hits to DataQueue.\nSend the reply.")
reply = get_ack_reply()
return reply

+ 6
- 2
mindinsight/debugger/debugger_server.py View File

@@ -212,7 +212,7 @@ class DebuggerServer:
Returns:
dict, reply with graph.
"""
log.info("Retrieve node %s.", filter_condition)
log.debug("Retrieve node %s.", filter_condition)
# validate node name
node_name = filter_condition.get('name')
if node_name:
@@ -262,7 +262,11 @@ class DebuggerServer:
"""
log.info("Retrieve tensor history for node: %s.", node_name)
self._validate_leaf_name(node_name)
res = self._get_tensor_history(node_name)
try:
res = self._get_tensor_history(node_name)
except MindInsightException:
log.warning("Failed to get tensor history for %s.", node_name)
res = {}
return res


def _validate_leaf_name(self, node_name):


+ 6
- 1
mindinsight/debugger/stream_handler/watchpoint_handler.py View File

@@ -205,6 +205,11 @@ class WatchpointHitHandler(StreamHandlerBase):
def __init__(self):
self._hits = {}


@property
def empty(self):
"""Whether the watchpoint hit is empty."""
return not self._hits

def put(self, value):
"""
Put value into watchpoint hit cache. Called by grpc server.
@@ -235,7 +240,7 @@ class WatchpointHitHandler(StreamHandlerBase):
Get watchpoint hit list.


Args:
filter_condition (str): Get the watchpoint hit according to specifiled node name.
filter_condition (str): Get the watchpoint hit according to specified node name.
If not given, get all watchpoint hits. Default: None.


Returns:


Loading…
Cancel
Save