Browse Source

add sending state in debugger

tags/v1.1.0
yelihua 5 years ago
parent
commit
cb816c3efd
6 changed files with 34 additions and 21 deletions
  1. +2
    -1
      mindinsight/debugger/common/utils.py
  2. +16
    -8
      mindinsight/debugger/debugger_grpc_server.py
  3. +2
    -0
      mindinsight/debugger/stream_handler/event_handler.py
  4. +5
    -4
      mindinsight/debugger/stream_operator/training_control_operator.py
  5. +6
    -5
      tests/st/func/debugger/test_restful_api.py
  6. +3
    -3
      tests/ut/debugger/stream_operator/test_training_control_operator.py

+ 2
- 1
mindinsight/debugger/common/utils.py View File

@@ -56,7 +56,8 @@ class ServerStatus(enum.Enum):
RECEIVE_GRAPH = 'receive graph' # the client session has sent the graph
WAITING = 'waiting' # the client session is ready
RUNNING = 'running' # the client session is running a script
MISMATCH = 'mismatch' # the version of Mindspore and Mindinsight is not matched
MISMATCH = 'mismatch' # the version of Mindspore and Mindinsight is not matched
SENDING = 'sending' # the request is in the cache but has not yet been sent to the client


@enum.unique


+ 16
- 8
mindinsight/debugger/debugger_grpc_server.py View File

@@ -104,7 +104,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):

# check if version is mismatch, if mismatch, send mismatch info to UI
if self._status == ServerStatus.MISMATCH:
log.warning("Version of Mindspore and Mindinsight re unmatched,"
log.warning("Version of MindSpore and MindInsight are unmatched,"
"waiting for user to terminate the script.")
metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
# put metadata into data queue
@@ -140,7 +140,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
self._cache_store.clean_command()
# receive graph in the beginning of the training
self._status = ServerStatus.WAITING
metadata_stream.state = 'waiting'
metadata_stream.state = ServerStatus.WAITING.value
metadata = metadata_stream.get()
res = self._cache_store.get_stream_handler(Streams.GRAPH).get()
res.update(metadata)
@@ -213,7 +213,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
event = self._deal_with_left_continue_step(left_step_count)
else:
event = self._deal_with_left_continue_node(node_name)
log.debug("Send old RunCMD. Clean watchpoint hit.")
log.debug("Send old RunCMD.")
return event

def _deal_with_left_continue_step(self, left_step_count):
@@ -270,7 +270,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
self._cache_store.get_stream_handler(Streams.METADATA).state = ServerStatus.WAITING.value
self._cache_store.put_data({'metadata': {'state': 'waiting'}})
event = None
while event is None and self._status in [ServerStatus.MISMATCH, ServerStatus.WAITING]:
while event is None and self._status not in [ServerStatus.RUNNING, ServerStatus.PENDING]:
log.debug("Wait for %s-th command", self._pos)
event = self._get_next_command()
return event
@@ -280,12 +280,16 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
self._pos, event = self._cache_store.get_command(self._pos)
if event is None:
return event
# deal with command
metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
if isinstance(event, dict):
event = self._deal_with_view_cmd(event)
elif event.HasField('run_cmd'):
event = self._deal_with_run_cmd(event)
self._cache_store.put_data(metadata_stream.get())
elif event.HasField('exit'):
self._cache_store.clean()
self._cache_store.put_data(metadata_stream.get())
log.debug("Clean cache for exit cmd.")
else:
self._cache_store.get_stream_handler(Streams.WATCHPOINT).clean_cache_set_cmd(event.set_cmd)
@@ -319,13 +323,16 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):

def _deal_with_run_cmd(self, event):
"""Deal with run cmd."""
metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
run_cmd = event.run_cmd
# receive step command
if run_cmd.run_level == 'step':
if run_cmd.run_level == RunLevel.STEP.value:
# receive pause cmd
if not run_cmd.run_steps:
log.debug("Pause training and wait for next command.")
self._old_run_cmd.clear()
# update metadata state from sending to waiting
metadata_stream.state = ServerStatus.WAITING.value
return None
# receive step cmd
left_steps = run_cmd.run_steps - 1
@@ -338,8 +345,9 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
# clean watchpoint hit cache
if run_cmd.run_level == RunLevel.RECHECK.value:
self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
log.debug("Receive RunCMD. Clean watchpoint hit cache.")

log.debug("Receive RunCMD. Clean watchpoint hit cache.")
# update metadata state from sending to running
metadata_stream.state = ServerStatus.RUNNING.value
return event

@debugger_wrap
@@ -364,7 +372,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
ms_version, mindinsight.__version__)
self._status = ServerStatus.MISMATCH
reply.version_matched = False
metadata_stream.state = 'mismatch'
metadata_stream.state = ServerStatus.MISMATCH.value
else:
log.info("version is matched.")
reply.version_matched = True


+ 2
- 0
mindinsight/debugger/stream_handler/event_handler.py View File

@@ -43,6 +43,8 @@ class EventHandler(StreamHandlerBase):
def has_pos(self, pos):
"""Get the event according to pos."""
cur_flag, cur_idx = self._parse_pos(pos)
if cur_flag not in [self._cur_flag, self._prev_flag]:
cur_flag, cur_idx = self._cur_flag, 0
event = self._event_cache[cur_idx]
if event is not None:
if not cur_flag or (cur_flag == self._cur_flag and cur_idx < self._next_idx) or \


+ 5
- 4
mindinsight/debugger/stream_operator/training_control_operator.py View File

@@ -95,7 +95,7 @@ class TrainingControlOperator:
raise DebuggerContinueError(
"MindSpore is not ready to run or is running currently."
)
metadata_stream.state = ServerStatus.RUNNING.value
metadata_stream.state = ServerStatus.SENDING.value
try:
self._validate_continue_params(params)
event = self._construct_run_event(params)
@@ -214,9 +214,10 @@ class TrainingControlOperator:
if metadata_stream.state != ServerStatus.RUNNING.value:
log.error("The MindSpore is not running.")
raise DebuggerPauseError("The MindSpore is not running.")
metadata_stream.state = 'waiting'
metadata_stream.state = ServerStatus.SENDING.value
event = get_ack_reply()
event.run_cmd.CopyFrom(RunCMD(run_level='step', run_steps=0))
self._cache_store.clean_command()
self._cache_store.put_command(event)
metadata_stream.enable_recheck = False
log.debug("Send the Pause command")
@@ -230,7 +231,7 @@ class TrainingControlOperator:
dict, metadata info.
"""
metadata_stream = self._metadata_stream
metadata_stream.state = 'pending'
metadata_stream.state = ServerStatus.SENDING.value
self._cache_store.clean_data()
self._cache_store.clean_command()
event = get_ack_reply()
@@ -252,7 +253,7 @@ class TrainingControlOperator:
if not metadata_stream.enable_recheck:
log.error("Recheck is not available.")
raise DebuggerRecheckError("Recheck is not available.")
metadata_stream.state = ServerStatus.RUNNING.value
metadata_stream.state = ServerStatus.SENDING.value
metadata_stream.enable_recheck = False
# send updated watchpoint and recheck command
try:


+ 6
- 5
tests/st/func/debugger/test_restful_api.py View File

@@ -307,12 +307,13 @@ class TestAscendDebugger:
body_data = {'mode': 'continue',
'steps': -1}
res = get_request_result(app_client, url, body_data)
assert res == {'metadata': {'state': 'running', 'enable_recheck': False}}
assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}}
# send pause command
check_state(app_client, 'running')
url = 'control'
body_data = {'mode': 'pause'}
res = get_request_result(app_client, url, body_data)
assert res == {'metadata': {'state': 'waiting', 'enable_recheck': False}}
assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}}
send_terminate_cmd(app_client)

@pytest.mark.level0
@@ -413,7 +414,7 @@ class TestGPUDebugger:
'level': 'node',
'name': 'Default/TransData-op99'}
res = get_request_result(app_client, url, body_data)
assert res == {'metadata': {'state': 'running', 'enable_recheck': False}}
assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}}
# get metadata
check_state(app_client)
url = 'retrieve'
@@ -617,7 +618,7 @@ class TestMultiGraphDebugger:
body_data = {'mode': 'continue'}
body_data.update(params)
res = get_request_result(app_client, url, body_data)
assert res == {'metadata': {'state': 'running', 'enable_recheck': False}}
assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}}
# get metadata
check_state(app_client)
url = 'retrieve'
@@ -669,7 +670,7 @@ def create_watchpoint_and_wait(app_client):
body_data = {'mode': 'continue',
'steps': 2}
res = get_request_result(app_client, url, body_data)
assert res == {'metadata': {'state': 'running', 'enable_recheck': False}}
assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}}
# wait for server has received watchpoint hit
check_state(app_client)



+ 3
- 3
tests/ut/debugger/stream_operator/test_training_control_operator.py View File

@@ -51,9 +51,9 @@ class TestTrainingControlOperator:
self._server._validate_continue_node_name(node_name='mock_node_name', graph_name='mock_graph_name')

@pytest.mark.parametrize('mode, cur_state, state', [
('continue', 'waiting', 'running'),
('pause', 'running', 'waiting'),
('terminate', 'waiting', 'pending')])
('continue', 'waiting', 'sending'),
('pause', 'running', 'sending'),
('terminate', 'waiting', 'sending')])
def test_control(self, mode, cur_state, state):
"""Test control request."""
with mock.patch.object(MetadataHandler, 'state', cur_state):


Loading…
Cancel
Save