From: @jiang-shuqiang Reviewed-by: @yelihua,@wenkai_dist Signed-off-by: @wenkai_distpull/1308/MERGE
| @@ -375,7 +375,7 @@ def set_recommended_watch_points(session_id): | |||||
| @BLUEPRINT.route("/debugger/sessions", methods=["POST"]) | @BLUEPRINT.route("/debugger/sessions", methods=["POST"]) | ||||
| def creat_session(): | |||||
| def create_session(): | |||||
| """ | """ | ||||
| Get session id if session exist, else create a session. | Get session id if session exist, else create a session. | ||||
| @@ -383,22 +383,22 @@ def creat_session(): | |||||
| str, session id. | str, session id. | ||||
| Examples: | Examples: | ||||
| >>> POST http://xxxx/v1/mindinsight/debugger/get-session | |||||
| >>> POST http://xxxx/v1/mindinsight/debugger/sessions | |||||
| """ | """ | ||||
| body = _read_post_request(request) | body = _read_post_request(request) | ||||
| summary_dir = body.get('dump_dir') | summary_dir = body.get('dump_dir') | ||||
| session_type = body.get('session_type') | session_type = body.get('session_type') | ||||
| reply = _wrap_reply(SessionManager.get_instance().creat_session, session_type, summary_dir) | |||||
| reply = _wrap_reply(SessionManager.get_instance().create_session, session_type, summary_dir) | |||||
| return reply | return reply | ||||
| @BLUEPRINT.route("/debugger/sessions", methods=["GET"]) | @BLUEPRINT.route("/debugger/sessions", methods=["GET"]) | ||||
| def get_sessions(): | |||||
| def get_train_jobs(): | |||||
| """ | """ | ||||
| Check the current active sessions. | Check the current active sessions. | ||||
| Examples: | Examples: | ||||
| >>> POST http://xxxx/v1/mindinsight/debugger/check-sessions | |||||
| >>> POST http://xxxx/v1/mindinsight/debugger/sessions | |||||
| """ | """ | ||||
| reply = _wrap_reply(SessionManager.get_instance().get_train_jobs) | reply = _wrap_reply(SessionManager.get_instance().get_train_jobs) | ||||
| return reply | return reply | ||||
| @@ -52,7 +52,7 @@ class DebuggerErrors(DebuggerErrorCodes): | |||||
| DEBUGGER_SESSION_OVER_BOUND_ERROR = 0 | _DEBUGGER_SESSION_ERROR | DEBUGGER_SESSION_OVER_BOUND_ERROR = 0 | _DEBUGGER_SESSION_ERROR | ||||
| DEBUGGER_SESSION_NOT_FOUND_ERROR = 1 | _DEBUGGER_SESSION_ERROR | DEBUGGER_SESSION_NOT_FOUND_ERROR = 1 | _DEBUGGER_SESSION_ERROR | ||||
| DEBUGGER_SESSION_ALREADY_EXIST_ERROR = 2 | _DEBUGGER_SESSION_ERROR | |||||
| DEBUGGER_ONLINE_SESSION_UNAVAILABLE = 2 | _DEBUGGER_SESSION_ERROR | |||||
| @unique | @unique | ||||
| @@ -80,4 +80,4 @@ class DebuggerErrorMsg(Enum): | |||||
| DEBUGGER_SESSION_OVER_BOUND_ERROR = "The amount of sessions is over limitation." | DEBUGGER_SESSION_OVER_BOUND_ERROR = "The amount of sessions is over limitation." | ||||
| DEBUGGER_SESSION_NOT_FOUND_ERROR = "Session {} not found." | DEBUGGER_SESSION_NOT_FOUND_ERROR = "Session {} not found." | ||||
| DEBUGGER_SESSION_ALREADY_EXIST_ERROR = "Session {} already exist." | |||||
| DEBUGGER_ONLINE_SESSION_UNAVAILABLE = "Online session is unavailable." | |||||
| @@ -247,12 +247,12 @@ class DebuggerSessionNotFoundError(MindInsightException): | |||||
| ) | ) | ||||
| class DebuggerSessionAlreadyExistError(MindInsightException): | |||||
| class DebuggerOnlineSessionUnavailable(MindInsightException): | |||||
| """The error of that the session already exist.""" | """The error of that the session already exist.""" | ||||
| def __init__(self, msg): | |||||
| super(DebuggerSessionAlreadyExistError, self).__init__( | |||||
| error=DebuggerErrors.DEBUGGER_SESSION_ALREADY_EXIST_ERROR, | |||||
| message=DebuggerErrorMsg.DEBUGGER_SESSION_ALREADY_EXIST_ERROR.value.format(msg), | |||||
| def __init__(self): | |||||
| super(DebuggerOnlineSessionUnavailable, self).__init__( | |||||
| error=DebuggerErrors.DEBUGGER_ONLINE_SESSION_UNAVAILABLE, | |||||
| message=DebuggerErrorMsg.DEBUGGER_ONLINE_SESSION_UNAVAILABLE.value, | |||||
| http_code=400 | http_code=400 | ||||
| ) | ) | ||||
| @@ -22,7 +22,7 @@ import _thread | |||||
| from mindinsight.conf import settings | from mindinsight.conf import settings | ||||
| from mindinsight.debugger.common.log import LOGGER as logger | from mindinsight.debugger.common.log import LOGGER as logger | ||||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerSessionNumOverBoundError, \ | from mindinsight.debugger.common.exceptions.exceptions import DebuggerSessionNumOverBoundError, \ | ||||
| DebuggerSessionNotFoundError, DebuggerSessionAlreadyExistError | |||||
| DebuggerSessionNotFoundError, DebuggerOnlineSessionUnavailable | |||||
| from mindinsight.debugger.debugger_services.debugger_server_factory import DebuggerServerContext | from mindinsight.debugger.debugger_services.debugger_server_factory import DebuggerServerContext | ||||
| from mindinsight.debugger.debugger_session import DebuggerSession | from mindinsight.debugger.debugger_session import DebuggerSession | ||||
| @@ -45,7 +45,7 @@ class SessionManager: | |||||
| self._exiting = False | self._exiting = False | ||||
| enable_debugger = settings.ENABLE_DEBUGGER if hasattr(settings, 'ENABLE_DEBUGGER') else False | enable_debugger = settings.ENABLE_DEBUGGER if hasattr(settings, 'ENABLE_DEBUGGER') else False | ||||
| if enable_debugger: | if enable_debugger: | ||||
| self.creat_session(self.ONLINE_TYPE) | |||||
| self._create_online_session() | |||||
| @classmethod | @classmethod | ||||
| def get_instance(cls): | def get_instance(cls): | ||||
| @@ -81,9 +81,36 @@ class SessionManager: | |||||
| logger.error('Debugger session %s is not found.', session_id) | logger.error('Debugger session %s is not found.', session_id) | ||||
| raise DebuggerSessionNotFoundError("{}".format(session_id)) | raise DebuggerSessionNotFoundError("{}".format(session_id)) | ||||
| def creat_session(self, session_type, train_job=None): | |||||
| def _create_online_session(self): | |||||
| """Create online session.""" | |||||
| with self._lock: | |||||
| context = DebuggerServerContext(dbg_mode='online') | |||||
| online_session = DebuggerSession(context) | |||||
| online_session.start() | |||||
| self.sessions[self.ONLINE_SESSION_ID] = online_session | |||||
| def _create_offline_session(self, train_job): | |||||
| """Create offline session.""" | |||||
| self._check_session_num() | |||||
| if not isinstance(train_job, str): | |||||
| logger.error('The train job path should be string.') | |||||
| raise ValueError("The train job path should be string.") | |||||
| summary_base_dir = settings.SUMMARY_BASE_DIR | |||||
| unquote_path = unquote(train_job, errors='strict') | |||||
| whole_path = os.path.join(summary_base_dir, unquote_path) | |||||
| normalized_path = validate_and_normalize_path(whole_path) | |||||
| context = DebuggerServerContext(dbg_mode='offline', train_job=train_job, dbg_dir=normalized_path) | |||||
| session = DebuggerSession(context) | |||||
| session.start() | |||||
| session_id = str(self._next_session_id) | |||||
| self.sessions[session_id] = session | |||||
| self.train_jobs[train_job] = session_id | |||||
| self._next_session_id += 1 | |||||
| return session_id | |||||
| def create_session(self, session_type, train_job=None): | |||||
| """ | """ | ||||
| Create session by the train job info. | |||||
| Create the session by the train job info or session type if the session doesn't exist. | |||||
| Args: | Args: | ||||
| session_type (str): The session_type. | session_type (str): The session_type. | ||||
| @@ -100,30 +127,16 @@ class SessionManager: | |||||
| if session_type == self.ONLINE_TYPE: | if session_type == self.ONLINE_TYPE: | ||||
| if self.ONLINE_SESSION_ID not in self.sessions: | if self.ONLINE_SESSION_ID not in self.sessions: | ||||
| context = DebuggerServerContext(dbg_mode='online') | |||||
| online_session = DebuggerSession(context) | |||||
| online_session.start() | |||||
| self.sessions[self.ONLINE_SESSION_ID] = online_session | |||||
| return self.ONLINE_SESSION_ID | |||||
| logger.error('Online session is already exist, cannot create online session twice.') | |||||
| raise DebuggerSessionAlreadyExistError("online session") | |||||
| logger.error( | |||||
| 'Online session is unavailable, set --enable-debugger as true/1 to enable debugger ' | |||||
| 'when start Mindinsight server.') | |||||
| raise DebuggerOnlineSessionUnavailable() | |||||
| return self.ONLINE_SESSION_ID | |||||
| if train_job in self.train_jobs: | if train_job in self.train_jobs: | ||||
| return self.train_jobs.get(train_job) | return self.train_jobs.get(train_job) | ||||
| self._check_session_num() | |||||
| summary_base_dir = settings.SUMMARY_BASE_DIR | |||||
| unquote_path = unquote(train_job, errors='strict') | |||||
| whole_path = os.path.join(summary_base_dir, unquote_path) | |||||
| normalized_path = validate_and_normalize_path(whole_path) | |||||
| context = DebuggerServerContext(dbg_mode='offline', train_job=train_job, dbg_dir=normalized_path) | |||||
| session = DebuggerSession(context) | |||||
| session.start() | |||||
| session_id = str(self._next_session_id) | |||||
| self.sessions[session_id] = session | |||||
| self.train_jobs[train_job] = session_id | |||||
| self._next_session_id += 1 | |||||
| return session_id | |||||
| return self._create_offline_session(train_job) | |||||
| def delete_session(self, session_id): | def delete_session(self, session_id): | ||||
| """Delete session by session id.""" | """Delete session by session id.""" | ||||
| @@ -46,6 +46,10 @@ class WatchpointHandler(StreamHandlerBase): | |||||
| # whether the watchpoint list has been changed since last step | # whether the watchpoint list has been changed since last step | ||||
| self._outdated = False | self._outdated = False | ||||
| def set_outdated(self): | |||||
| """"Set outdated as True.""" | |||||
| self._outdated = True | |||||
| def put(self, value): | def put(self, value): | ||||
| """ | """ | ||||
| Put Watchpoint into watchpoint handler. | Put Watchpoint into watchpoint handler. | ||||
| @@ -306,7 +306,8 @@ class TrainingControlOperator: | |||||
| self._cache_store.get_stream_handler(Streams.TENSOR).set_step(step_id) | self._cache_store.get_stream_handler(Streams.TENSOR).set_step(step_id) | ||||
| self._cache_store.clean_data() | self._cache_store.clean_data() | ||||
| self._cache_store.clean_command() | self._cache_store.clean_command() | ||||
| metadata_stream.enable_recheck = False | |||||
| metadata_stream.enable_recheck = True | |||||
| metadata_stream.state = ServerStatus.WAITING.value | metadata_stream.state = ServerStatus.WAITING.value | ||||
| self._cache_store.get_stream_handler(Streams.WATCHPOINT).set_outdated() | |||||
| log.debug("Send the Change_training_step CMD.") | log.debug("Send the Change_training_step CMD.") | ||||
| return metadata_stream.get(['state', 'enable_recheck', 'step']) | return metadata_stream.get(['state', 'enable_recheck', 'step']) | ||||
| @@ -73,4 +73,4 @@ class TestTrainingControlOperator: | |||||
| with mock.patch.object(MetadataHandler, 'max_step_num', 10), \ | with mock.patch.object(MetadataHandler, 'max_step_num', 10), \ | ||||
| mock.patch.object(MetadataHandler, 'debugger_type', 'offline'): | mock.patch.object(MetadataHandler, 'debugger_type', 'offline'): | ||||
| res = self._server.control(mode=mode, params={'steps': 9}) | res = self._server.control(mode=mode, params={'steps': 9}) | ||||
| assert res == {'metadata': {'enable_recheck': False, 'state': state, 'step': 9}} | |||||
| assert res == {'metadata': {'enable_recheck': True, 'state': state, 'step': 9}} | |||||