diff --git a/mindinsight/debugger/common/utils.py b/mindinsight/debugger/common/utils.py index 8c21ac62..24c313b5 100644 --- a/mindinsight/debugger/common/utils.py +++ b/mindinsight/debugger/common/utils.py @@ -56,7 +56,8 @@ class ServerStatus(enum.Enum): RECEIVE_GRAPH = 'receive graph' # the client session has sent the graph WAITING = 'waiting' # the client session is ready RUNNING = 'running' # the client session is running a script - MISMATCH = 'mismatch' # the version of Mindspore and Mindinsight is not matched + MISMATCH = 'mismatch' # the version of Mindspore and Mindinsight is not matched + SENDING = 'sending' # the request is in cache but not be sent to client @enum.unique diff --git a/mindinsight/debugger/debugger_grpc_server.py b/mindinsight/debugger/debugger_grpc_server.py index 5f619c2e..6fd2654c 100644 --- a/mindinsight/debugger/debugger_grpc_server.py +++ b/mindinsight/debugger/debugger_grpc_server.py @@ -104,7 +104,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): # check if version is mismatch, if mismatch, send mismatch info to UI if self._status == ServerStatus.MISMATCH: - log.warning("Version of Mindspore and Mindinsight re unmatched," + log.warning("Version of MindSpore and MindInsight are unmatched," "waiting for user to terminate the script.") metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA) # put metadata into data queue @@ -140,7 +140,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): self._cache_store.clean_command() # receive graph in the beginning of the training self._status = ServerStatus.WAITING - metadata_stream.state = 'waiting' + metadata_stream.state = ServerStatus.WAITING.value metadata = metadata_stream.get() res = self._cache_store.get_stream_handler(Streams.GRAPH).get() res.update(metadata) @@ -213,7 +213,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): event = self._deal_with_left_continue_step(left_step_count) else: event = self._deal_with_left_continue_node(node_name) - log.debug("Send old RunCMD. Clean watchpoint hit.") + log.debug("Send old RunCMD.") return event def _deal_with_left_continue_step(self, left_step_count): @@ -270,7 +270,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): self._cache_store.get_stream_handler(Streams.METADATA).state = ServerStatus.WAITING.value self._cache_store.put_data({'metadata': {'state': 'waiting'}}) event = None - while event is None and self._status in [ServerStatus.MISMATCH, ServerStatus.WAITING]: + while event is None and self._status not in [ServerStatus.RUNNING, ServerStatus.PENDING]: log.debug("Wait for %s-th command", self._pos) event = self._get_next_command() return event @@ -280,12 +280,16 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): self._pos, event = self._cache_store.get_command(self._pos) if event is None: return event + # deal with command + metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA) if isinstance(event, dict): event = self._deal_with_view_cmd(event) elif event.HasField('run_cmd'): event = self._deal_with_run_cmd(event) + self._cache_store.put_data(metadata_stream.get()) elif event.HasField('exit'): self._cache_store.clean() + self._cache_store.put_data(metadata_stream.get()) log.debug("Clean cache for exit cmd.") else: self._cache_store.get_stream_handler(Streams.WATCHPOINT).clean_cache_set_cmd(event.set_cmd) @@ -319,13 +323,16 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): def _deal_with_run_cmd(self, event): """Deal with run cmd.""" + metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA) run_cmd = event.run_cmd # receive step command - if run_cmd.run_level == 'step': + if run_cmd.run_level == RunLevel.STEP.value: # receive pause cmd if not run_cmd.run_steps: log.debug("Pause training and wait for next command.") self._old_run_cmd.clear() + # update metadata state from sending to waiting + metadata_stream.state = ServerStatus.WAITING.value return None # receive step cmd left_steps = run_cmd.run_steps - 1 @@ -338,8 +345,9 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): # clean watchpoint hit cache if run_cmd.run_level == RunLevel.RECHECK.value: self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean() - log.debug("Receive RunCMD. Clean watchpoint hit cache.") - + log.debug("Receive RunCMD. Clean watchpoint hit cache.") + # update metadata state from sending to running + metadata_stream.state = ServerStatus.RUNNING.value return event @debugger_wrap @@ -364,7 +372,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): ms_version, mindinsight.__version__) self._status = ServerStatus.MISMATCH reply.version_matched = False - metadata_stream.state = 'mismatch' + metadata_stream.state = ServerStatus.MISMATCH.value else: log.info("version is matched.") reply.version_matched = True diff --git a/mindinsight/debugger/stream_handler/event_handler.py b/mindinsight/debugger/stream_handler/event_handler.py index 2338daf9..d9419a51 100644 --- a/mindinsight/debugger/stream_handler/event_handler.py +++ b/mindinsight/debugger/stream_handler/event_handler.py @@ -43,6 +43,8 @@ class EventHandler(StreamHandlerBase): def has_pos(self, pos): """Get the event according to pos.""" cur_flag, cur_idx = self._parse_pos(pos) + if cur_flag not in [self._cur_flag, self._prev_flag]: + cur_flag, cur_idx = self._cur_flag, 0 event = self._event_cache[cur_idx] if event is not None: if not cur_flag or (cur_flag == self._cur_flag and cur_idx < self._next_idx) or \ diff --git a/mindinsight/debugger/stream_operator/training_control_operator.py b/mindinsight/debugger/stream_operator/training_control_operator.py index 76da792e..913b185c 100644 --- a/mindinsight/debugger/stream_operator/training_control_operator.py +++ b/mindinsight/debugger/stream_operator/training_control_operator.py @@ -95,7 +95,7 @@ class TrainingControlOperator: raise DebuggerContinueError( "MindSpore is not ready to run or is running currently." ) - metadata_stream.state = ServerStatus.RUNNING.value + metadata_stream.state = ServerStatus.SENDING.value try: self._validate_continue_params(params) event = self._construct_run_event(params) @@ -214,9 +214,10 @@ class TrainingControlOperator: if metadata_stream.state != ServerStatus.RUNNING.value: log.error("The MindSpore is not running.") raise DebuggerPauseError("The MindSpore is not running.") - metadata_stream.state = 'waiting' + metadata_stream.state = ServerStatus.SENDING.value event = get_ack_reply() event.run_cmd.CopyFrom(RunCMD(run_level='step', run_steps=0)) + self._cache_store.clean_command() self._cache_store.put_command(event) metadata_stream.enable_recheck = False log.debug("Send the Pause command") @@ -230,7 +231,7 @@ class TrainingControlOperator: dict, metadata info. """ metadata_stream = self._metadata_stream - metadata_stream.state = 'pending' + metadata_stream.state = ServerStatus.SENDING.value self._cache_store.clean_data() self._cache_store.clean_command() event = get_ack_reply() @@ -252,7 +253,7 @@ class TrainingControlOperator: if not metadata_stream.enable_recheck: log.error("Recheck is not available.") raise DebuggerRecheckError("Recheck is not available.") - metadata_stream.state = ServerStatus.RUNNING.value + metadata_stream.state = ServerStatus.SENDING.value metadata_stream.enable_recheck = False # send updated watchpoint and recheck command try: diff --git a/tests/st/func/debugger/test_restful_api.py b/tests/st/func/debugger/test_restful_api.py index 71958b97..2701dc86 100644 --- a/tests/st/func/debugger/test_restful_api.py +++ b/tests/st/func/debugger/test_restful_api.py @@ -307,12 +307,13 @@ class TestAscendDebugger: body_data = {'mode': 'continue', 'steps': -1} res = get_request_result(app_client, url, body_data) - assert res == {'metadata': {'state': 'running', 'enable_recheck': False}} + assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}} # send pause command + check_state(app_client, 'running') url = 'control' body_data = {'mode': 'pause'} res = get_request_result(app_client, url, body_data) - assert res == {'metadata': {'state': 'waiting', 'enable_recheck': False}} + assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}} send_terminate_cmd(app_client) @pytest.mark.level0 @@ -413,7 +414,7 @@ class TestGPUDebugger: 'level': 'node', 'name': 'Default/TransData-op99'} res = get_request_result(app_client, url, body_data) - assert res == {'metadata': {'state': 'running', 'enable_recheck': False}} + assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}} # get metadata check_state(app_client) url = 'retrieve' @@ -617,7 +618,7 @@ class TestMultiGraphDebugger: body_data = {'mode': 'continue'} body_data.update(params) res = get_request_result(app_client, url, body_data) - assert res == {'metadata': {'state': 'running', 'enable_recheck': False}} + assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}} # get metadata check_state(app_client) url = 'retrieve' @@ -669,7 +670,7 @@ def create_watchpoint_and_wait(app_client): body_data = {'mode': 'continue', 'steps': 2} res = get_request_result(app_client, url, body_data) - assert res == {'metadata': {'state': 'running', 'enable_recheck': False}} + assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}} # wait for server has received watchpoint hit check_state(app_client) diff --git a/tests/ut/debugger/stream_operator/test_training_control_operator.py b/tests/ut/debugger/stream_operator/test_training_control_operator.py index 43de2254..1d1d5ef6 100644 --- a/tests/ut/debugger/stream_operator/test_training_control_operator.py +++ b/tests/ut/debugger/stream_operator/test_training_control_operator.py @@ -51,9 +51,9 @@ class TestTrainingControlOperator: self._server._validate_continue_node_name(node_name='mock_node_name', graph_name='mock_graph_name') @pytest.mark.parametrize('mode, cur_state, state', [ - ('continue', 'waiting', 'running'), - ('pause', 'running', 'waiting'), - ('terminate', 'waiting', 'pending')]) + ('continue', 'waiting', 'sending'), + ('pause', 'running', 'sending'), + ('terminate', 'waiting', 'sending')]) def test_control(self, mode, cur_state, state): """Test control request.""" with mock.patch.object(MetadataHandler, 'state', cur_state):