
add sending state in debugger

tags/v1.1.0
yelihua, 5 years ago
commit cb816c3efd
6 changed files with 34 additions and 21 deletions
1. mindinsight/debugger/common/utils.py (+2 -1)
2. mindinsight/debugger/debugger_grpc_server.py (+16 -8)
3. mindinsight/debugger/stream_handler/event_handler.py (+2 -0)
4. mindinsight/debugger/stream_operator/training_control_operator.py (+5 -4)
5. tests/st/func/debugger/test_restful_api.py (+6 -5)
6. tests/ut/debugger/stream_operator/test_training_control_operator.py (+3 -3)

mindinsight/debugger/common/utils.py (+2 -1)

@@ -56,7 +56,8 @@ class ServerStatus(enum.Enum):
     RECEIVE_GRAPH = 'receive graph'  # the client session has sent the graph
     WAITING = 'waiting'  # the client session is ready
     RUNNING = 'running'  # the client session is running a script
-    MISMATCH = 'mismatch'  # the version of Mindspore and Mindinsight is not matched
+    MISMATCH = 'mismatch'  # the version of Mindspore and Mindinsight is not matched
+    SENDING = 'sending'  # the request is in cache but not be sent to client


 @enum.unique
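Taken together with the changes below, the new SENDING value marks the window between a UI control request being queued and the gRPC server actually delivering it to the MindSpore client. A minimal, self-contained sketch of that two-phase transition; the enum values follow this diff, while MetadataStream, queue_run_command and deliver_next_command are illustrative stand-ins, not the repository's real classes or call chain:

    import enum

    @enum.unique
    class ServerStatus(enum.Enum):
        """Subset of the debugger states involved in this change (illustrative)."""
        WAITING = 'waiting'    # the client session is ready
        RUNNING = 'running'    # the client session is running a script
        SENDING = 'sending'    # the request is in cache but not yet sent to the client

    class MetadataStream:
        """Stand-in for the real metadata stream handler."""
        def __init__(self):
            self.state = ServerStatus.WAITING.value

    def queue_run_command(metadata_stream, command_cache, event):
        """REST/UI side (cf. TrainingControlOperator): the request is queued, not delivered."""
        metadata_stream.state = ServerStatus.SENDING.value
        command_cache.append(event)

    def deliver_next_command(metadata_stream, command_cache):
        """gRPC side (cf. DebuggerGrpcServer): the state only becomes RUNNING on delivery."""
        event = command_cache.pop(0)
        metadata_stream.state = ServerStatus.RUNNING.value
        return event

    # usage: the UI sees 'sending' first, 'running' only once the command is handed over
    stream, cache = MetadataStream(), []
    queue_run_command(stream, cache, {'run_cmd': {'run_level': 'step', 'run_steps': 1}})
    assert stream.state == 'sending'
    deliver_next_command(stream, cache)
    assert stream.state == 'running'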


mindinsight/debugger/debugger_grpc_server.py (+16 -8)

@@ -104,7 +104,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
         # check if version is mismatch, if mismatch, send mismatch info to UI
         if self._status == ServerStatus.MISMATCH:
-            log.warning("Version of Mindspore and Mindinsight re unmatched,"
+            log.warning("Version of MindSpore and MindInsight are unmatched,"
                         "waiting for user to terminate the script.")
             metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
             # put metadata into data queue
@@ -140,7 +140,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
         self._cache_store.clean_command()
         # receive graph in the beginning of the training
         self._status = ServerStatus.WAITING
-        metadata_stream.state = 'waiting'
+        metadata_stream.state = ServerStatus.WAITING.value
         metadata = metadata_stream.get()
         res = self._cache_store.get_stream_handler(Streams.GRAPH).get()
         res.update(metadata)
@@ -213,7 +213,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
             event = self._deal_with_left_continue_step(left_step_count)
         else:
             event = self._deal_with_left_continue_node(node_name)
-        log.debug("Send old RunCMD. Clean watchpoint hit.")
+        log.debug("Send old RunCMD.")
         return event

     def _deal_with_left_continue_step(self, left_step_count):
@@ -270,7 +270,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
         self._cache_store.get_stream_handler(Streams.METADATA).state = ServerStatus.WAITING.value
         self._cache_store.put_data({'metadata': {'state': 'waiting'}})
         event = None
-        while event is None and self._status in [ServerStatus.MISMATCH, ServerStatus.WAITING]:
+        while event is None and self._status not in [ServerStatus.RUNNING, ServerStatus.PENDING]:
             log.debug("Wait for %s-th command", self._pos)
             event = self._get_next_command()
         return event
@@ -280,12 +280,16 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
         self._pos, event = self._cache_store.get_command(self._pos)
         if event is None:
             return event
+        # deal with command
+        metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
         if isinstance(event, dict):
             event = self._deal_with_view_cmd(event)
         elif event.HasField('run_cmd'):
             event = self._deal_with_run_cmd(event)
+            self._cache_store.put_data(metadata_stream.get())
         elif event.HasField('exit'):
             self._cache_store.clean()
+            self._cache_store.put_data(metadata_stream.get())
             log.debug("Clean cache for exit cmd.")
         else:
             self._cache_store.get_stream_handler(Streams.WATCHPOINT).clean_cache_set_cmd(event.set_cmd)
@@ -319,13 +323,16 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):

     def _deal_with_run_cmd(self, event):
         """Deal with run cmd."""
+        metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
         run_cmd = event.run_cmd
         # receive step command
-        if run_cmd.run_level == 'step':
+        if run_cmd.run_level == RunLevel.STEP.value:
             # receive pause cmd
             if not run_cmd.run_steps:
                 log.debug("Pause training and wait for next command.")
                 self._old_run_cmd.clear()
+                # update metadata state from sending to waiting
+                metadata_stream.state = ServerStatus.WAITING.value
                 return None
             # receive step cmd
             left_steps = run_cmd.run_steps - 1
@@ -338,8 +345,9 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
         # clean watchpoint hit cache
         if run_cmd.run_level == RunLevel.RECHECK.value:
             self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
-        log.debug("Receive RunCMD. Clean watchpoint hit cache.")
-
+            log.debug("Receive RunCMD. Clean watchpoint hit cache.")
+        # update metadata state from sending to running
+        metadata_stream.state = ServerStatus.RUNNING.value
         return event

     @debugger_wrap
@@ -364,7 +372,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
                      ms_version, mindinsight.__version__)
             self._status = ServerStatus.MISMATCH
             reply.version_matched = False
-            metadata_stream.state = 'mismatch'
+            metadata_stream.state = ServerStatus.MISMATCH.value
         else:
             log.info("version is matched.")
             reply.version_matched = True
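The gRPC server is the consuming side of the hand-off: the state only leaves 'sending' once a queued RunCMD is actually picked up, and _deal_with_run_cmd decides whether the published state becomes 'waiting' (pause) or 'running' (everything else). A condensed, hypothetical helper capturing just that decision; the real code works on the protobuf EventReply rather than plain arguments:

    def next_state_for_run_cmd(run_level, run_steps):
        """Which metadata state to publish after consuming a RunCMD (illustrative).

        In this protocol a pause is encoded as a step command with run_steps == 0
        (see the pause path in training_control_operator.py), so it falls back to
        'waiting'; every other run command means the script is about to execute,
        hence 'running'.
        """
        if run_level == 'step' and not run_steps:
            return 'waiting'
        return 'running'

    assert next_state_for_run_cmd('step', 0) == 'waiting'    # pause
    assert next_state_for_run_cmd('step', 3) == 'running'    # run 3 steps
    assert next_state_for_run_cmd('recheck', 0) == 'running'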


mindinsight/debugger/stream_handler/event_handler.py (+2 -0)

@@ -43,6 +43,8 @@ class EventHandler(StreamHandlerBase):
     def has_pos(self, pos):
         """Get the event according to pos."""
         cur_flag, cur_idx = self._parse_pos(pos)
+        if cur_flag not in [self._cur_flag, self._prev_flag]:
+            cur_flag, cur_idx = self._cur_flag, 0
         event = self._event_cache[cur_idx]
         if event is not None:
             if not cur_flag or (cur_flag == self._cur_flag and cur_idx < self._next_idx) or \
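The two added lines guard against a stale position token from the UI: if the flag parsed out of pos matches neither the current nor the previous cache generation, has_pos restarts from index 0 of the current generation instead of indexing the event cache with a dangling flag. A toy version of that guard; the real pos format and _parse_pos are not shown in this diff, so the function below is purely illustrative:

    def normalize_pos(cur_flag, cur_idx, current_flag, prev_flag):
        """Drop a stale flag and restart from index 0 of the current generation."""
        if cur_flag not in (current_flag, prev_flag):
            return current_flag, 0
        return cur_flag, cur_idx

    # a client holding a pos from two generations ago is redirected to the newest events
    assert normalize_pos('old', 7, current_flag='new', prev_flag='mid') == ('new', 0)
    assert normalize_pos('mid', 3, current_flag='new', prev_flag='mid') == ('mid', 3)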


mindinsight/debugger/stream_operator/training_control_operator.py (+5 -4)

@@ -95,7 +95,7 @@ class TrainingControlOperator:
             raise DebuggerContinueError(
                 "MindSpore is not ready to run or is running currently."
             )
-        metadata_stream.state = ServerStatus.RUNNING.value
+        metadata_stream.state = ServerStatus.SENDING.value
         try:
             self._validate_continue_params(params)
             event = self._construct_run_event(params)
@@ -214,9 +214,10 @@
         if metadata_stream.state != ServerStatus.RUNNING.value:
             log.error("The MindSpore is not running.")
             raise DebuggerPauseError("The MindSpore is not running.")
-        metadata_stream.state = 'waiting'
+        metadata_stream.state = ServerStatus.SENDING.value
         event = get_ack_reply()
         event.run_cmd.CopyFrom(RunCMD(run_level='step', run_steps=0))
+        self._cache_store.clean_command()
         self._cache_store.put_command(event)
         metadata_stream.enable_recheck = False
         log.debug("Send the Pause command")
@@ -230,7 +231,7 @@
             dict, metadata info.
         """
         metadata_stream = self._metadata_stream
-        metadata_stream.state = 'pending'
+        metadata_stream.state = ServerStatus.SENDING.value
         self._cache_store.clean_data()
         self._cache_store.clean_command()
         event = get_ack_reply()
@@ -252,7 +253,7 @@
         if not metadata_stream.enable_recheck:
             log.error("Recheck is not available.")
             raise DebuggerRecheckError("Recheck is not available.")
-        metadata_stream.state = ServerStatus.RUNNING.value
+        metadata_stream.state = ServerStatus.SENDING.value
         metadata_stream.enable_recheck = False
         # send updated watchpoint and recheck command
         try:
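On the producer side, every control path (continue, pause, terminate, recheck) now reports 'sending' in its immediate response and leaves the definitive state to the gRPC server; the pause path additionally calls clean_command() so a queued-but-unsent command is dropped before the pause is enqueued. A hypothetical condensation of that pattern; queue_control_command and drop_stale are illustrative names rather than the repository's API, while cache_store.clean_command/put_command and metadata_stream.state are the attributes visible in this diff:

    def queue_control_command(metadata_stream, cache_store, event, drop_stale=False):
        """Illustrative producer side of a debugger control request.

        Every path sets the metadata state to 'sending' before the command is
        queued; drop_stale mirrors the clean_command() call added to the pause
        path. The gRPC server later moves the state on to 'running' or 'waiting'.
        """
        metadata_stream.state = 'sending'
        if drop_stale:
            cache_store.clean_command()
        cache_store.put_command(event)
        return {'metadata': {'state': metadata_stream.state, 'enable_recheck': False}}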


tests/st/func/debugger/test_restful_api.py (+6 -5)

@@ -307,12 +307,13 @@ class TestAscendDebugger:
         body_data = {'mode': 'continue',
                      'steps': -1}
         res = get_request_result(app_client, url, body_data)
-        assert res == {'metadata': {'state': 'running', 'enable_recheck': False}}
+        assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}}
         # send pause command
+        check_state(app_client, 'running')
         url = 'control'
         body_data = {'mode': 'pause'}
         res = get_request_result(app_client, url, body_data)
-        assert res == {'metadata': {'state': 'waiting', 'enable_recheck': False}}
+        assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}}
         send_terminate_cmd(app_client)

     @pytest.mark.level0
@@ -413,7 +414,7 @@ class TestGPUDebugger:
                      'level': 'node',
                      'name': 'Default/TransData-op99'}
         res = get_request_result(app_client, url, body_data)
-        assert res == {'metadata': {'state': 'running', 'enable_recheck': False}}
+        assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}}
         # get metadata
         check_state(app_client)
         url = 'retrieve'
@@ -617,7 +618,7 @@ class TestMultiGraphDebugger:
         body_data = {'mode': 'continue'}
         body_data.update(params)
         res = get_request_result(app_client, url, body_data)
-        assert res == {'metadata': {'state': 'running', 'enable_recheck': False}}
+        assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}}
         # get metadata
         check_state(app_client)
         url = 'retrieve'
@@ -669,7 +670,7 @@ def create_watchpoint_and_wait(app_client):
     body_data = {'mode': 'continue',
                  'steps': 2}
     res = get_request_result(app_client, url, body_data)
-    assert res == {'metadata': {'state': 'running', 'enable_recheck': False}}
+    assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}}
     # wait for server has received watchpoint hit
     check_state(app_client)
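Because the immediate response now only acknowledges that the command was queued, the system tests poll for the real state before issuing a follow-up command, which is what the new check_state(app_client, 'running') call does. Roughly, such a poll amounts to the following; wait_for_state is a hypothetical helper written against a generic callable, since the exact implementation of the check_state test utility is not shown in this diff:

    import time

    def wait_for_state(fetch_metadata, expected_state, retries=30, interval=0.1):
        """Poll until the debugger metadata reaches the expected state (illustrative).

        fetch_metadata is any zero-argument callable returning the 'metadata' dict,
        e.g. a thin wrapper around get_request_result(app_client, 'retrieve', body).
        """
        for _ in range(retries):
            if fetch_metadata().get('state') == expected_state:
                return
            time.sleep(interval)
        raise AssertionError('state did not reach %r' % expected_state)

    # example with a canned response instead of a live app_client
    wait_for_state(lambda: {'state': 'running'}, 'running')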




tests/ut/debugger/stream_operator/test_training_control_operator.py (+3 -3)

@@ -51,9 +51,9 @@ class TestTrainingControlOperator:
         self._server._validate_continue_node_name(node_name='mock_node_name', graph_name='mock_graph_name')

     @pytest.mark.parametrize('mode, cur_state, state', [
-        ('continue', 'waiting', 'running'),
-        ('pause', 'running', 'waiting'),
-        ('terminate', 'waiting', 'pending')])
+        ('continue', 'waiting', 'sending'),
+        ('pause', 'running', 'sending'),
+        ('terminate', 'waiting', 'sending')])
     def test_control(self, mode, cur_state, state):
         """Test control request."""
         with mock.patch.object(MetadataHandler, 'state', cur_state):

