From: @yelihua Reviewed-by: @wangyue01,@lilongfei15 Signed-off-by: @lilongfei15tags/v1.1.0
| @@ -56,7 +56,8 @@ class ServerStatus(enum.Enum): | |||
| RECEIVE_GRAPH = 'receive graph' # the client session has sent the graph | |||
| WAITING = 'waiting' # the client session is ready | |||
| RUNNING = 'running' # the client session is running a script | |||
| MISMATCH = 'mismatch' # the version of Mindspore and Mindinsight is not matched | |||
| MISMATCH = 'mismatch' # the version of Mindspore and Mindinsight is not matched | |||
| SENDING = 'sending' # the request is in cache but not be sent to client | |||
| @enum.unique | |||
| @@ -104,7 +104,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| # check if version is mismatch, if mismatch, send mismatch info to UI | |||
| if self._status == ServerStatus.MISMATCH: | |||
| log.warning("Version of Mindspore and Mindinsight re unmatched," | |||
| log.warning("Version of MindSpore and MindInsight are unmatched," | |||
| "waiting for user to terminate the script.") | |||
| metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA) | |||
| # put metadata into data queue | |||
| @@ -140,7 +140,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| self._cache_store.clean_command() | |||
| # receive graph in the beginning of the training | |||
| self._status = ServerStatus.WAITING | |||
| metadata_stream.state = 'waiting' | |||
| metadata_stream.state = ServerStatus.WAITING.value | |||
| metadata = metadata_stream.get() | |||
| res = self._cache_store.get_stream_handler(Streams.GRAPH).get() | |||
| res.update(metadata) | |||
| @@ -213,7 +213,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| event = self._deal_with_left_continue_step(left_step_count) | |||
| else: | |||
| event = self._deal_with_left_continue_node(node_name) | |||
| log.debug("Send old RunCMD. Clean watchpoint hit.") | |||
| log.debug("Send old RunCMD.") | |||
| return event | |||
| def _deal_with_left_continue_step(self, left_step_count): | |||
| @@ -270,7 +270,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| self._cache_store.get_stream_handler(Streams.METADATA).state = ServerStatus.WAITING.value | |||
| self._cache_store.put_data({'metadata': {'state': 'waiting'}}) | |||
| event = None | |||
| while event is None and self._status in [ServerStatus.MISMATCH, ServerStatus.WAITING]: | |||
| while event is None and self._status not in [ServerStatus.RUNNING, ServerStatus.PENDING]: | |||
| log.debug("Wait for %s-th command", self._pos) | |||
| event = self._get_next_command() | |||
| return event | |||
| @@ -280,12 +280,16 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| self._pos, event = self._cache_store.get_command(self._pos) | |||
| if event is None: | |||
| return event | |||
| # deal with command | |||
| metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA) | |||
| if isinstance(event, dict): | |||
| event = self._deal_with_view_cmd(event) | |||
| elif event.HasField('run_cmd'): | |||
| event = self._deal_with_run_cmd(event) | |||
| self._cache_store.put_data(metadata_stream.get()) | |||
| elif event.HasField('exit'): | |||
| self._cache_store.clean() | |||
| self._cache_store.put_data(metadata_stream.get()) | |||
| log.debug("Clean cache for exit cmd.") | |||
| else: | |||
| self._cache_store.get_stream_handler(Streams.WATCHPOINT).clean_cache_set_cmd(event.set_cmd) | |||
| @@ -319,13 +323,16 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| def _deal_with_run_cmd(self, event): | |||
| """Deal with run cmd.""" | |||
| metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA) | |||
| run_cmd = event.run_cmd | |||
| # receive step command | |||
| if run_cmd.run_level == 'step': | |||
| if run_cmd.run_level == RunLevel.STEP.value: | |||
| # receive pause cmd | |||
| if not run_cmd.run_steps: | |||
| log.debug("Pause training and wait for next command.") | |||
| self._old_run_cmd.clear() | |||
| # update metadata state from sending to waiting | |||
| metadata_stream.state = ServerStatus.WAITING.value | |||
| return None | |||
| # receive step cmd | |||
| left_steps = run_cmd.run_steps - 1 | |||
| @@ -338,8 +345,9 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| # clean watchpoint hit cache | |||
| if run_cmd.run_level == RunLevel.RECHECK.value: | |||
| self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean() | |||
| log.debug("Receive RunCMD. Clean watchpoint hit cache.") | |||
| log.debug("Receive RunCMD. Clean watchpoint hit cache.") | |||
| # update metadata state from sending to running | |||
| metadata_stream.state = ServerStatus.RUNNING.value | |||
| return event | |||
| @debugger_wrap | |||
| @@ -364,7 +372,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| ms_version, mindinsight.__version__) | |||
| self._status = ServerStatus.MISMATCH | |||
| reply.version_matched = False | |||
| metadata_stream.state = 'mismatch' | |||
| metadata_stream.state = ServerStatus.MISMATCH.value | |||
| else: | |||
| log.info("version is matched.") | |||
| reply.version_matched = True | |||
| @@ -43,6 +43,8 @@ class EventHandler(StreamHandlerBase): | |||
| def has_pos(self, pos): | |||
| """Get the event according to pos.""" | |||
| cur_flag, cur_idx = self._parse_pos(pos) | |||
| if cur_flag not in [self._cur_flag, self._prev_flag]: | |||
| cur_flag, cur_idx = self._cur_flag, 0 | |||
| event = self._event_cache[cur_idx] | |||
| if event is not None: | |||
| if not cur_flag or (cur_flag == self._cur_flag and cur_idx < self._next_idx) or \ | |||
| @@ -95,7 +95,7 @@ class TrainingControlOperator: | |||
| raise DebuggerContinueError( | |||
| "MindSpore is not ready to run or is running currently." | |||
| ) | |||
| metadata_stream.state = ServerStatus.RUNNING.value | |||
| metadata_stream.state = ServerStatus.SENDING.value | |||
| try: | |||
| self._validate_continue_params(params) | |||
| event = self._construct_run_event(params) | |||
| @@ -214,9 +214,10 @@ class TrainingControlOperator: | |||
| if metadata_stream.state != ServerStatus.RUNNING.value: | |||
| log.error("The MindSpore is not running.") | |||
| raise DebuggerPauseError("The MindSpore is not running.") | |||
| metadata_stream.state = 'waiting' | |||
| metadata_stream.state = ServerStatus.SENDING.value | |||
| event = get_ack_reply() | |||
| event.run_cmd.CopyFrom(RunCMD(run_level='step', run_steps=0)) | |||
| self._cache_store.clean_command() | |||
| self._cache_store.put_command(event) | |||
| metadata_stream.enable_recheck = False | |||
| log.debug("Send the Pause command") | |||
| @@ -230,7 +231,7 @@ class TrainingControlOperator: | |||
| dict, metadata info. | |||
| """ | |||
| metadata_stream = self._metadata_stream | |||
| metadata_stream.state = 'pending' | |||
| metadata_stream.state = ServerStatus.SENDING.value | |||
| self._cache_store.clean_data() | |||
| self._cache_store.clean_command() | |||
| event = get_ack_reply() | |||
| @@ -252,7 +253,7 @@ class TrainingControlOperator: | |||
| if not metadata_stream.enable_recheck: | |||
| log.error("Recheck is not available.") | |||
| raise DebuggerRecheckError("Recheck is not available.") | |||
| metadata_stream.state = ServerStatus.RUNNING.value | |||
| metadata_stream.state = ServerStatus.SENDING.value | |||
| metadata_stream.enable_recheck = False | |||
| # send updated watchpoint and recheck command | |||
| try: | |||
| @@ -307,12 +307,13 @@ class TestAscendDebugger: | |||
| body_data = {'mode': 'continue', | |||
| 'steps': -1} | |||
| res = get_request_result(app_client, url, body_data) | |||
| assert res == {'metadata': {'state': 'running', 'enable_recheck': False}} | |||
| assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}} | |||
| # send pause command | |||
| check_state(app_client, 'running') | |||
| url = 'control' | |||
| body_data = {'mode': 'pause'} | |||
| res = get_request_result(app_client, url, body_data) | |||
| assert res == {'metadata': {'state': 'waiting', 'enable_recheck': False}} | |||
| assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}} | |||
| send_terminate_cmd(app_client) | |||
| @pytest.mark.level0 | |||
| @@ -413,7 +414,7 @@ class TestGPUDebugger: | |||
| 'level': 'node', | |||
| 'name': 'Default/TransData-op99'} | |||
| res = get_request_result(app_client, url, body_data) | |||
| assert res == {'metadata': {'state': 'running', 'enable_recheck': False}} | |||
| assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}} | |||
| # get metadata | |||
| check_state(app_client) | |||
| url = 'retrieve' | |||
| @@ -617,7 +618,7 @@ class TestMultiGraphDebugger: | |||
| body_data = {'mode': 'continue'} | |||
| body_data.update(params) | |||
| res = get_request_result(app_client, url, body_data) | |||
| assert res == {'metadata': {'state': 'running', 'enable_recheck': False}} | |||
| assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}} | |||
| # get metadata | |||
| check_state(app_client) | |||
| url = 'retrieve' | |||
| @@ -669,7 +670,7 @@ def create_watchpoint_and_wait(app_client): | |||
| body_data = {'mode': 'continue', | |||
| 'steps': 2} | |||
| res = get_request_result(app_client, url, body_data) | |||
| assert res == {'metadata': {'state': 'running', 'enable_recheck': False}} | |||
| assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}} | |||
| # wait for server has received watchpoint hit | |||
| check_state(app_client) | |||
| @@ -51,9 +51,9 @@ class TestTrainingControlOperator: | |||
| self._server._validate_continue_node_name(node_name='mock_node_name', graph_name='mock_graph_name') | |||
| @pytest.mark.parametrize('mode, cur_state, state', [ | |||
| ('continue', 'waiting', 'running'), | |||
| ('pause', 'running', 'waiting'), | |||
| ('terminate', 'waiting', 'pending')]) | |||
| ('continue', 'waiting', 'sending'), | |||
| ('pause', 'running', 'sending'), | |||
| ('terminate', 'waiting', 'sending')]) | |||
| def test_control(self, mode, cur_state, state): | |||
| """Test control request.""" | |||
| with mock.patch.object(MetadataHandler, 'state', cur_state): | |||