@@ -71,26 +71,14 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
             log.warning("No graph received before WaitCMD.")
             reply = get_ack_reply(1)
             return reply
-        self._send_received_tensor_tag()
-        # send graph if has not been sent before
+        # send graph if it has not been sent before
+        self._pre_process(request)
+        # deal with old command
+        reply = self._deal_with_old_command()
+        if reply:
+            log.info("Reply to WaitCMD with old command: %s", reply)
+            return reply
-        # continue multiple steps training
-        if self._continue_steps:
-            reply = get_ack_reply()
-            reply.run_cmd.run_steps = 1
-            reply.run_cmd.run_level = 'step'
-            self._continue_steps = self._continue_steps - 1 if self._continue_steps > 0 else -1
-            self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
-            log.debug("Send RunCMD. Clean watchpoint hit.")
-        # wait for command
-        else:
+        # wait for next command
+        if reply is None:
             reply = self._wait_for_next_command()

+        # check the reply
         if reply is None:
             reply = get_ack_reply(1)
             log.warning("Failed to get command event.")
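
For orientation, here is a minimal sketch of the two reply shapes this hunk juggles, together with the `_continue_steps` countdown in isolation. It assumes `get_ack_reply(state=0)` returns an `EventReply` protobuf whose error status is set from its argument; only the field names used above come from the diff, the helper functions themselves are hypothetical.

    # Hypothetical helpers, assuming get_ack_reply(state=0) builds an
    # EventReply and a non-zero argument marks it as an error acknowledgement.
    def error_reply():
        return get_ack_reply(1)             # e.g. no graph received yet

    def run_one_step_reply():
        reply = get_ack_reply()             # success acknowledgement
        reply.run_cmd.run_steps = 1         # let training advance one step
        reply.run_cmd.run_level = 'step'    # step-level rather than node-level
        return reply

    # The countdown used above, in isolation: positive counts step down to 0,
    # anything else collapses to -1, which reads as "keep running".
    def next_continue_steps(continue_steps):
        return continue_steps - 1 if continue_steps > 0 else -1

    assert next_continue_steps(3) == 2
    assert next_continue_steps(1) == 0
    assert next_continue_steps(-1) == -1
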
@@ -98,42 +86,48 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
         log.info("Reply to WaitCMD: %s", reply)
         return reply

-    def _send_received_tensor_tag(self):
-        """Send received_finish_tag."""
-        node_name = self._received_view_cmd.get('node_name')
-        if not node_name or self._received_view_cmd.get('wait_for_tensor'):
-            return
-        metadata = self._cache_store.get_stream_handler(Streams.METADATA).get()
-        ret = {'receive_tensor': {'node_name': node_name}}
-        ret.update(metadata)
-        self._cache_store.put_data(ret)
-        self._received_view_cmd.clear()
-        log.debug("Send receive tensor flag for %s", node_name)
-
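
A standalone illustration of the dict merge inside `_send_received_tensor_tag`: a `receive_tensor` marker is combined with the current metadata into one payload. The metadata values here are made up for illustration; the real dict comes from the METADATA stream handler.

    # Illustrative values only; the real metadata dict comes from
    # self._cache_store.get_stream_handler(Streams.METADATA).get().
    metadata = {'state': 'waiting', 'step': 2}
    ret = {'receive_tensor': {'node_name': 'Default/conv1'}}
    ret.update(metadata)
    assert ret == {'receive_tensor': {'node_name': 'Default/conv1'},
                   'state': 'waiting', 'step': 2}
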
     def _pre_process(self, request):
-        """Send graph and metadata when WaitCMD first called."""
+        """Pre-process before dealing with command."""
         metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
+        is_new_step = metadata_stream.step < request.cur_step
+        # clean cache data at the beginning of new step
+        if is_new_step:
+            self._cache_store.clean_data()
+            self._cache_store.get_stream_handler(Streams.TENSOR).clean_tensors(request.cur_step)
+        # receive graph at the beginning of the training
         if self._status == ServerStatus.RECEIVE_GRAPH:
-            self._status = ServerStatus.WAITING
-            metadata_stream.state = 'waiting'
-            metadata = metadata_stream.get()
-            self._cache_store.clean_command()
-            res = self._cache_store.get_stream_handler(Streams.GRAPH).get()
-            res.update(metadata)
-            self._cache_store.put_data(res)
-            log.debug("Put graph into data queue.")
-
-        if metadata_stream.step < request.cur_step or metadata_stream.full_name != request.cur_node:
-            # clean tensor cache and DataQueue at the beginning of each step
+            self._send_graph_flag(metadata_stream)
+        # receive new metadata
+        if is_new_step or metadata_stream.full_name != request.cur_node:
             self._update_metadata(metadata_stream, request)
+        self._send_received_tensor_tag()
+        self._send_watchpoint_hit_flag()

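
The two change tests that `_pre_process` keys on can be read in isolation; a minimal sketch with hypothetical stubs standing in for the metadata handler and the gRPC request:

    # MetaStub/ReqStub are hypothetical stand-ins, not real classes.
    class MetaStub:
        step, full_name = 3, 'Default/conv1'

    class ReqStub:
        cur_step, cur_node = 4, 'Default/conv2'

    is_new_step = MetaStub.step < ReqStub.cur_step          # step advanced
    node_changed = MetaStub.full_name != ReqStub.cur_node   # execution moved on
    # caches are cleaned only on a new step; metadata is refreshed on either.
    assert is_new_step and node_changed
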
+    def _send_graph_flag(self, metadata_stream):
+        """
+        Send graph and metadata to UI.
+
+        Args:
+            metadata_stream (MetadataHandler): Metadata handler stream.
+        """
+        self._cache_store.clean_command()
+        # receive graph at the beginning of the training
+        self._status = ServerStatus.WAITING
+        metadata_stream.state = 'waiting'
+        metadata = metadata_stream.get()
+        res = self._cache_store.get_stream_handler(Streams.GRAPH).get()
+        res.update(metadata)
+        self._cache_store.put_data(res)
+        log.debug("Put graph into data queue.")
+
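
`_send_graph_flag` ships the graph and the metadata as one combined payload; a hedged sketch of the resulting dict, assuming the GRAPH handler's `get()` returns a dict keyed by `'graph'` (the exact schema is not shown in this diff):

    # Shapes are assumptions for illustration, not the exact schema.
    metadata = {'state': 'waiting', 'step': 0}
    res = {'graph': {'nodes': []}}   # from the Streams.GRAPH handler
    res.update(metadata)             # one payload for the UI data queue
    assert res == {'graph': {'nodes': []}, 'state': 'waiting', 'step': 0}
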
     def _update_metadata(self, metadata_stream, metadata_proto):
-        """Update metadata."""
-        # reset view round and clean cache data
-        if metadata_stream.step < metadata_proto.cur_step:
-            self._cache_store.clean_data()
-            self._cache_store.get_stream_handler(Streams.TENSOR).clean_tensors(
-                metadata_proto.cur_step)
+        """
+        Update metadata.
+
+        Args:
+            metadata_stream (MetadataHandler): Metadata handler stream.
+            metadata_proto (MetadataProto): Metadata proto sent by the client.
+        """
         # put new metadata into cache
         metadata_stream.put(metadata_proto)
         cur_node = self._cache_store.get_stream_handler(Streams.GRAPH).get_node_name_by_full_name(
@@ -143,12 +137,41 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
         self._cache_store.put_data(metadata)
         log.debug("Put new metadata into data queue.")

+    def _send_received_tensor_tag(self):
+        """Send received_finish_tag."""
+        node_name = self._received_view_cmd.get('node_name')
+        if not node_name or self._received_view_cmd.get('wait_for_tensor'):
+            return
+        metadata = self._cache_store.get_stream_handler(Streams.METADATA).get()
+        ret = {'receive_tensor': {'node_name': node_name}}
+        ret.update(metadata)
+        self._cache_store.put_data(ret)
+        self._received_view_cmd.clear()
+        log.debug("Send receive tensor flag for %s", node_name)
+
+    def _send_watchpoint_hit_flag(self):
+        """Send Watchpoint hit flag."""
+        watchpoint_hit_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
+        if watchpoint_hit_stream.empty:
+            return
+        watchpoint_hits_info = watchpoint_hit_stream.get()
+        self._cache_store.put_data(watchpoint_hits_info)
+        log.debug("Send the watchpoint hits to DataQueue.\nSend the reply.")
+
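
`_send_watchpoint_hit_flag` follows a flush-if-nonempty pattern; a self-contained sketch with a hypothetical stream stub (`empty` mirrors the property used above, the payload key is invented):

    # StreamStub is a hypothetical stand-in for the WATCHPOINT_HIT handler.
    class StreamStub:
        def __init__(self, hits):
            self._hits = hits

        @property
        def empty(self):
            return not self._hits

        def get(self):
            return {'watchpoint_hits': list(self._hits)}

    stream = StreamStub(hits=[{'node_name': 'Default/conv1'}])
    if not stream.empty:
        payload = stream.get()   # publish only when there is something to show
        assert payload == {'watchpoint_hits': [{'node_name': 'Default/conv1'}]}
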
     def _deal_with_old_command(self):
         """Deal with old command."""
         event = None
         while self._cache_store.has_command(self._pos) and event is None:
             event = self._get_next_command()
             log.debug("Deal with old %s-th command:\n%s.", self._pos, event)
+        # continue multiple steps training
+        if event is None and self._continue_steps:
+            event = get_ack_reply()
+            event.run_cmd.run_steps = 1
+            event.run_cmd.run_level = 'step'
+            self._continue_steps = self._continue_steps - 1 if self._continue_steps > 0 else -1
+            self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
+            log.debug("Send RunCMD. Clean watchpoint hit.")

         return event

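
The `while` loop in `_deal_with_old_command` drains cached commands until one yields a usable event; the same pattern in a self-contained form, with a hypothetical cache stub in place of `self._cache_store` and `self._pos`:

    # CacheStub is a hypothetical stand-in for the real command cache.
    class CacheStub:
        def __init__(self, commands):
            self._commands = commands

        def has_command(self, pos):
            return pos < len(self._commands)

        def get_command(self, pos):
            return self._commands[pos], pos + 1

    cache, pos, event = CacheStub([None, None, 'run_cmd']), 0, None
    while cache.has_command(pos) and event is None:
        event, pos = cache.get_command(pos)   # skip stale slots, advance cursor
    assert event == 'run_cmd' and pos == 3
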
@@ -295,8 +318,5 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
                 'node_name': ui_node_name
             }
             watchpoint_hit_stream.put(watchpoint_hit)
-        watchpoint_hits_info = watchpoint_hit_stream.get()
-        self._cache_store.put_data(watchpoint_hits_info)
-        log.debug("Send the watchpoint hits to DataQueue.\nSend the reply.")
         reply = get_ack_reply()
         return reply