| @@ -56,7 +56,8 @@ class ServerStatus(enum.Enum): | |||||
| RECEIVE_GRAPH = 'receive graph' # the client session has sent the graph | RECEIVE_GRAPH = 'receive graph' # the client session has sent the graph | ||||
| WAITING = 'waiting' # the client session is ready | WAITING = 'waiting' # the client session is ready | ||||
| RUNNING = 'running' # the client session is running a script | RUNNING = 'running' # the client session is running a script | ||||
| MISMATCH = 'mismatch' # the version of Mindspore and Mindinsight is not matched | |||||
| MISMATCH = 'mismatch' # the version of Mindspore and Mindinsight is not matched | |||||
| SENDING = 'sending' # the request is in cache but not be sent to client | |||||
| @enum.unique | @enum.unique | ||||
| @@ -104,7 +104,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||||
| # check if version is mismatch, if mismatch, send mismatch info to UI | # check if version is mismatch, if mismatch, send mismatch info to UI | ||||
| if self._status == ServerStatus.MISMATCH: | if self._status == ServerStatus.MISMATCH: | ||||
| log.warning("Version of Mindspore and Mindinsight re unmatched," | |||||
| log.warning("Version of MindSpore and MindInsight are unmatched," | |||||
| "waiting for user to terminate the script.") | "waiting for user to terminate the script.") | ||||
| metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA) | metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA) | ||||
| # put metadata into data queue | # put metadata into data queue | ||||
| @@ -140,7 +140,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||||
| self._cache_store.clean_command() | self._cache_store.clean_command() | ||||
| # receive graph in the beginning of the training | # receive graph in the beginning of the training | ||||
| self._status = ServerStatus.WAITING | self._status = ServerStatus.WAITING | ||||
| metadata_stream.state = 'waiting' | |||||
| metadata_stream.state = ServerStatus.WAITING.value | |||||
| metadata = metadata_stream.get() | metadata = metadata_stream.get() | ||||
| res = self._cache_store.get_stream_handler(Streams.GRAPH).get() | res = self._cache_store.get_stream_handler(Streams.GRAPH).get() | ||||
| res.update(metadata) | res.update(metadata) | ||||
| @@ -213,7 +213,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||||
| event = self._deal_with_left_continue_step(left_step_count) | event = self._deal_with_left_continue_step(left_step_count) | ||||
| else: | else: | ||||
| event = self._deal_with_left_continue_node(node_name) | event = self._deal_with_left_continue_node(node_name) | ||||
| log.debug("Send old RunCMD. Clean watchpoint hit.") | |||||
| log.debug("Send old RunCMD.") | |||||
| return event | return event | ||||
| def _deal_with_left_continue_step(self, left_step_count): | def _deal_with_left_continue_step(self, left_step_count): | ||||
| @@ -270,7 +270,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||||
| self._cache_store.get_stream_handler(Streams.METADATA).state = ServerStatus.WAITING.value | self._cache_store.get_stream_handler(Streams.METADATA).state = ServerStatus.WAITING.value | ||||
| self._cache_store.put_data({'metadata': {'state': 'waiting'}}) | self._cache_store.put_data({'metadata': {'state': 'waiting'}}) | ||||
| event = None | event = None | ||||
| while event is None and self._status in [ServerStatus.MISMATCH, ServerStatus.WAITING]: | |||||
| while event is None and self._status not in [ServerStatus.RUNNING, ServerStatus.PENDING]: | |||||
| log.debug("Wait for %s-th command", self._pos) | log.debug("Wait for %s-th command", self._pos) | ||||
| event = self._get_next_command() | event = self._get_next_command() | ||||
| return event | return event | ||||
| @@ -280,12 +280,16 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||||
| self._pos, event = self._cache_store.get_command(self._pos) | self._pos, event = self._cache_store.get_command(self._pos) | ||||
| if event is None: | if event is None: | ||||
| return event | return event | ||||
| # deal with command | |||||
| metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA) | |||||
| if isinstance(event, dict): | if isinstance(event, dict): | ||||
| event = self._deal_with_view_cmd(event) | event = self._deal_with_view_cmd(event) | ||||
| elif event.HasField('run_cmd'): | elif event.HasField('run_cmd'): | ||||
| event = self._deal_with_run_cmd(event) | event = self._deal_with_run_cmd(event) | ||||
| self._cache_store.put_data(metadata_stream.get()) | |||||
| elif event.HasField('exit'): | elif event.HasField('exit'): | ||||
| self._cache_store.clean() | self._cache_store.clean() | ||||
| self._cache_store.put_data(metadata_stream.get()) | |||||
| log.debug("Clean cache for exit cmd.") | log.debug("Clean cache for exit cmd.") | ||||
| else: | else: | ||||
| self._cache_store.get_stream_handler(Streams.WATCHPOINT).clean_cache_set_cmd(event.set_cmd) | self._cache_store.get_stream_handler(Streams.WATCHPOINT).clean_cache_set_cmd(event.set_cmd) | ||||
| @@ -319,13 +323,16 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||||
| def _deal_with_run_cmd(self, event): | def _deal_with_run_cmd(self, event): | ||||
| """Deal with run cmd.""" | """Deal with run cmd.""" | ||||
| metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA) | |||||
| run_cmd = event.run_cmd | run_cmd = event.run_cmd | ||||
| # receive step command | # receive step command | ||||
| if run_cmd.run_level == 'step': | |||||
| if run_cmd.run_level == RunLevel.STEP.value: | |||||
| # receive pause cmd | # receive pause cmd | ||||
| if not run_cmd.run_steps: | if not run_cmd.run_steps: | ||||
| log.debug("Pause training and wait for next command.") | log.debug("Pause training and wait for next command.") | ||||
| self._old_run_cmd.clear() | self._old_run_cmd.clear() | ||||
| # update metadata state from sending to waiting | |||||
| metadata_stream.state = ServerStatus.WAITING.value | |||||
| return None | return None | ||||
| # receive step cmd | # receive step cmd | ||||
| left_steps = run_cmd.run_steps - 1 | left_steps = run_cmd.run_steps - 1 | ||||
| @@ -338,8 +345,9 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||||
| # clean watchpoint hit cache | # clean watchpoint hit cache | ||||
| if run_cmd.run_level == RunLevel.RECHECK.value: | if run_cmd.run_level == RunLevel.RECHECK.value: | ||||
| self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean() | self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean() | ||||
| log.debug("Receive RunCMD. Clean watchpoint hit cache.") | |||||
| log.debug("Receive RunCMD. Clean watchpoint hit cache.") | |||||
| # update metadata state from sending to running | |||||
| metadata_stream.state = ServerStatus.RUNNING.value | |||||
| return event | return event | ||||
| @debugger_wrap | @debugger_wrap | ||||
| @@ -364,7 +372,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||||
| ms_version, mindinsight.__version__) | ms_version, mindinsight.__version__) | ||||
| self._status = ServerStatus.MISMATCH | self._status = ServerStatus.MISMATCH | ||||
| reply.version_matched = False | reply.version_matched = False | ||||
| metadata_stream.state = 'mismatch' | |||||
| metadata_stream.state = ServerStatus.MISMATCH.value | |||||
| else: | else: | ||||
| log.info("version is matched.") | log.info("version is matched.") | ||||
| reply.version_matched = True | reply.version_matched = True | ||||
| @@ -43,6 +43,8 @@ class EventHandler(StreamHandlerBase): | |||||
| def has_pos(self, pos): | def has_pos(self, pos): | ||||
| """Get the event according to pos.""" | """Get the event according to pos.""" | ||||
| cur_flag, cur_idx = self._parse_pos(pos) | cur_flag, cur_idx = self._parse_pos(pos) | ||||
| if cur_flag not in [self._cur_flag, self._prev_flag]: | |||||
| cur_flag, cur_idx = self._cur_flag, 0 | |||||
| event = self._event_cache[cur_idx] | event = self._event_cache[cur_idx] | ||||
| if event is not None: | if event is not None: | ||||
| if not cur_flag or (cur_flag == self._cur_flag and cur_idx < self._next_idx) or \ | if not cur_flag or (cur_flag == self._cur_flag and cur_idx < self._next_idx) or \ | ||||
| @@ -95,7 +95,7 @@ class TrainingControlOperator: | |||||
| raise DebuggerContinueError( | raise DebuggerContinueError( | ||||
| "MindSpore is not ready to run or is running currently." | "MindSpore is not ready to run or is running currently." | ||||
| ) | ) | ||||
| metadata_stream.state = ServerStatus.RUNNING.value | |||||
| metadata_stream.state = ServerStatus.SENDING.value | |||||
| try: | try: | ||||
| self._validate_continue_params(params) | self._validate_continue_params(params) | ||||
| event = self._construct_run_event(params) | event = self._construct_run_event(params) | ||||
| @@ -214,9 +214,10 @@ class TrainingControlOperator: | |||||
| if metadata_stream.state != ServerStatus.RUNNING.value: | if metadata_stream.state != ServerStatus.RUNNING.value: | ||||
| log.error("The MindSpore is not running.") | log.error("The MindSpore is not running.") | ||||
| raise DebuggerPauseError("The MindSpore is not running.") | raise DebuggerPauseError("The MindSpore is not running.") | ||||
| metadata_stream.state = 'waiting' | |||||
| metadata_stream.state = ServerStatus.SENDING.value | |||||
| event = get_ack_reply() | event = get_ack_reply() | ||||
| event.run_cmd.CopyFrom(RunCMD(run_level='step', run_steps=0)) | event.run_cmd.CopyFrom(RunCMD(run_level='step', run_steps=0)) | ||||
| self._cache_store.clean_command() | |||||
| self._cache_store.put_command(event) | self._cache_store.put_command(event) | ||||
| metadata_stream.enable_recheck = False | metadata_stream.enable_recheck = False | ||||
| log.debug("Send the Pause command") | log.debug("Send the Pause command") | ||||
| @@ -230,7 +231,7 @@ class TrainingControlOperator: | |||||
| dict, metadata info. | dict, metadata info. | ||||
| """ | """ | ||||
| metadata_stream = self._metadata_stream | metadata_stream = self._metadata_stream | ||||
| metadata_stream.state = 'pending' | |||||
| metadata_stream.state = ServerStatus.SENDING.value | |||||
| self._cache_store.clean_data() | self._cache_store.clean_data() | ||||
| self._cache_store.clean_command() | self._cache_store.clean_command() | ||||
| event = get_ack_reply() | event = get_ack_reply() | ||||
| @@ -252,7 +253,7 @@ class TrainingControlOperator: | |||||
| if not metadata_stream.enable_recheck: | if not metadata_stream.enable_recheck: | ||||
| log.error("Recheck is not available.") | log.error("Recheck is not available.") | ||||
| raise DebuggerRecheckError("Recheck is not available.") | raise DebuggerRecheckError("Recheck is not available.") | ||||
| metadata_stream.state = ServerStatus.RUNNING.value | |||||
| metadata_stream.state = ServerStatus.SENDING.value | |||||
| metadata_stream.enable_recheck = False | metadata_stream.enable_recheck = False | ||||
| # send updated watchpoint and recheck command | # send updated watchpoint and recheck command | ||||
| try: | try: | ||||
| @@ -307,12 +307,13 @@ class TestAscendDebugger: | |||||
| body_data = {'mode': 'continue', | body_data = {'mode': 'continue', | ||||
| 'steps': -1} | 'steps': -1} | ||||
| res = get_request_result(app_client, url, body_data) | res = get_request_result(app_client, url, body_data) | ||||
| assert res == {'metadata': {'state': 'running', 'enable_recheck': False}} | |||||
| assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}} | |||||
| # send pause command | # send pause command | ||||
| check_state(app_client, 'running') | |||||
| url = 'control' | url = 'control' | ||||
| body_data = {'mode': 'pause'} | body_data = {'mode': 'pause'} | ||||
| res = get_request_result(app_client, url, body_data) | res = get_request_result(app_client, url, body_data) | ||||
| assert res == {'metadata': {'state': 'waiting', 'enable_recheck': False}} | |||||
| assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}} | |||||
| send_terminate_cmd(app_client) | send_terminate_cmd(app_client) | ||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @@ -413,7 +414,7 @@ class TestGPUDebugger: | |||||
| 'level': 'node', | 'level': 'node', | ||||
| 'name': 'Default/TransData-op99'} | 'name': 'Default/TransData-op99'} | ||||
| res = get_request_result(app_client, url, body_data) | res = get_request_result(app_client, url, body_data) | ||||
| assert res == {'metadata': {'state': 'running', 'enable_recheck': False}} | |||||
| assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}} | |||||
| # get metadata | # get metadata | ||||
| check_state(app_client) | check_state(app_client) | ||||
| url = 'retrieve' | url = 'retrieve' | ||||
| @@ -617,7 +618,7 @@ class TestMultiGraphDebugger: | |||||
| body_data = {'mode': 'continue'} | body_data = {'mode': 'continue'} | ||||
| body_data.update(params) | body_data.update(params) | ||||
| res = get_request_result(app_client, url, body_data) | res = get_request_result(app_client, url, body_data) | ||||
| assert res == {'metadata': {'state': 'running', 'enable_recheck': False}} | |||||
| assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}} | |||||
| # get metadata | # get metadata | ||||
| check_state(app_client) | check_state(app_client) | ||||
| url = 'retrieve' | url = 'retrieve' | ||||
| @@ -669,7 +670,7 @@ def create_watchpoint_and_wait(app_client): | |||||
| body_data = {'mode': 'continue', | body_data = {'mode': 'continue', | ||||
| 'steps': 2} | 'steps': 2} | ||||
| res = get_request_result(app_client, url, body_data) | res = get_request_result(app_client, url, body_data) | ||||
| assert res == {'metadata': {'state': 'running', 'enable_recheck': False}} | |||||
| assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}} | |||||
| # wait for server has received watchpoint hit | # wait for server has received watchpoint hit | ||||
| check_state(app_client) | check_state(app_client) | ||||
| @@ -51,9 +51,9 @@ class TestTrainingControlOperator: | |||||
| self._server._validate_continue_node_name(node_name='mock_node_name', graph_name='mock_graph_name') | self._server._validate_continue_node_name(node_name='mock_node_name', graph_name='mock_graph_name') | ||||
| @pytest.mark.parametrize('mode, cur_state, state', [ | @pytest.mark.parametrize('mode, cur_state, state', [ | ||||
| ('continue', 'waiting', 'running'), | |||||
| ('pause', 'running', 'waiting'), | |||||
| ('terminate', 'waiting', 'pending')]) | |||||
| ('continue', 'waiting', 'sending'), | |||||
| ('pause', 'running', 'sending'), | |||||
| ('terminate', 'waiting', 'sending')]) | |||||
| def test_control(self, mode, cur_state, state): | def test_control(self, mode, cur_state, state): | ||||
| """Test control request.""" | """Test control request.""" | ||||
| with mock.patch.object(MetadataHandler, 'state', cur_state): | with mock.patch.object(MetadataHandler, 'state', cur_state): | ||||