From: @jiang-shuqiang Reviewed-by: @yelihua, @wenkai_dist Signed-off-by: @wenkai_dist pull/1308/MERGE
| @@ -375,7 +375,7 @@ def set_recommended_watch_points(session_id): | |||
| @BLUEPRINT.route("/debugger/sessions", methods=["POST"]) | |||
| def creat_session(): | |||
| def create_session(): | |||
| """ | |||
| Get the session id if the session exists, else create a session. | |||
| @@ -383,22 +383,22 @@ def creat_session(): | |||
| str, session id. | |||
| Examples: | |||
| >>> POST http://xxxx/v1/mindinsight/debugger/get-session | |||
| >>> POST http://xxxx/v1/mindinsight/debugger/sessions | |||
| """ | |||
| body = _read_post_request(request) | |||
| summary_dir = body.get('dump_dir') | |||
| session_type = body.get('session_type') | |||
| reply = _wrap_reply(SessionManager.get_instance().creat_session, session_type, summary_dir) | |||
| reply = _wrap_reply(SessionManager.get_instance().create_session, session_type, summary_dir) | |||
| return reply | |||
| @BLUEPRINT.route("/debugger/sessions", methods=["GET"]) | |||
| def get_sessions(): | |||
| def get_train_jobs(): | |||
| """ | |||
| Check the current active sessions. | |||
| Examples: | |||
| >>> POST http://xxxx/v1/mindinsight/debugger/check-sessions | |||
| >>> GET http://xxxx/v1/mindinsight/debugger/sessions | |||
| """ | |||
| reply = _wrap_reply(SessionManager.get_instance().get_train_jobs) | |||
| return reply | |||
| @@ -52,7 +52,7 @@ class DebuggerErrors(DebuggerErrorCodes): | |||
| DEBUGGER_SESSION_OVER_BOUND_ERROR = 0 | _DEBUGGER_SESSION_ERROR | |||
| DEBUGGER_SESSION_NOT_FOUND_ERROR = 1 | _DEBUGGER_SESSION_ERROR | |||
| DEBUGGER_SESSION_ALREADY_EXIST_ERROR = 2 | _DEBUGGER_SESSION_ERROR | |||
| DEBUGGER_ONLINE_SESSION_UNAVAILABLE = 2 | _DEBUGGER_SESSION_ERROR | |||
| @unique | |||
| @@ -80,4 +80,4 @@ class DebuggerErrorMsg(Enum): | |||
| DEBUGGER_SESSION_OVER_BOUND_ERROR = "The amount of sessions is over limitation." | |||
| DEBUGGER_SESSION_NOT_FOUND_ERROR = "Session {} not found." | |||
| DEBUGGER_SESSION_ALREADY_EXIST_ERROR = "Session {} already exist." | |||
| DEBUGGER_ONLINE_SESSION_UNAVAILABLE = "Online session is unavailable." | |||
| @@ -247,12 +247,12 @@ class DebuggerSessionNotFoundError(MindInsightException): | |||
| ) | |||
| class DebuggerSessionAlreadyExistError(MindInsightException): | |||
| class DebuggerOnlineSessionUnavailable(MindInsightException): | |||
| """The error that the online session is unavailable.""" | |||
| def __init__(self, msg): | |||
| super(DebuggerSessionAlreadyExistError, self).__init__( | |||
| error=DebuggerErrors.DEBUGGER_SESSION_ALREADY_EXIST_ERROR, | |||
| message=DebuggerErrorMsg.DEBUGGER_SESSION_ALREADY_EXIST_ERROR.value.format(msg), | |||
| def __init__(self): | |||
| super(DebuggerOnlineSessionUnavailable, self).__init__( | |||
| error=DebuggerErrors.DEBUGGER_ONLINE_SESSION_UNAVAILABLE, | |||
| message=DebuggerErrorMsg.DEBUGGER_ONLINE_SESSION_UNAVAILABLE.value, | |||
| http_code=400 | |||
| ) | |||
| @@ -22,7 +22,7 @@ import _thread | |||
| from mindinsight.conf import settings | |||
| from mindinsight.debugger.common.log import LOGGER as logger | |||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerSessionNumOverBoundError, \ | |||
| DebuggerSessionNotFoundError, DebuggerSessionAlreadyExistError | |||
| DebuggerSessionNotFoundError, DebuggerOnlineSessionUnavailable | |||
| from mindinsight.debugger.debugger_services.debugger_server_factory import DebuggerServerContext | |||
| from mindinsight.debugger.debugger_session import DebuggerSession | |||
| @@ -45,7 +45,7 @@ class SessionManager: | |||
| self._exiting = False | |||
| enable_debugger = settings.ENABLE_DEBUGGER if hasattr(settings, 'ENABLE_DEBUGGER') else False | |||
| if enable_debugger: | |||
| self.creat_session(self.ONLINE_TYPE) | |||
| self._create_online_session() | |||
| @classmethod | |||
| def get_instance(cls): | |||
| @@ -81,9 +81,36 @@ class SessionManager: | |||
| logger.error('Debugger session %s is not found.', session_id) | |||
| raise DebuggerSessionNotFoundError("{}".format(session_id)) | |||
| def creat_session(self, session_type, train_job=None): | |||
| def _create_online_session(self): | |||
| """Create online session.""" | |||
| with self._lock: | |||
| context = DebuggerServerContext(dbg_mode='online') | |||
| online_session = DebuggerSession(context) | |||
| online_session.start() | |||
| self.sessions[self.ONLINE_SESSION_ID] = online_session | |||
| def _create_offline_session(self, train_job): | |||
| """Create offline session.""" | |||
| self._check_session_num() | |||
| if not isinstance(train_job, str): | |||
| logger.error('The train job path should be string.') | |||
| raise ValueError("The train job path should be string.") | |||
| summary_base_dir = settings.SUMMARY_BASE_DIR | |||
| unquote_path = unquote(train_job, errors='strict') | |||
| whole_path = os.path.join(summary_base_dir, unquote_path) | |||
| normalized_path = validate_and_normalize_path(whole_path) | |||
| context = DebuggerServerContext(dbg_mode='offline', train_job=train_job, dbg_dir=normalized_path) | |||
| session = DebuggerSession(context) | |||
| session.start() | |||
| session_id = str(self._next_session_id) | |||
| self.sessions[session_id] = session | |||
| self.train_jobs[train_job] = session_id | |||
| self._next_session_id += 1 | |||
| return session_id | |||
| def create_session(self, session_type, train_job=None): | |||
| """ | |||
| Create session by the train job info. | |||
| Create the session by the train job info or session type if the session doesn't exist. | |||
| Args: | |||
| session_type (str): The session_type. | |||
| @@ -100,30 +127,16 @@ class SessionManager: | |||
| if session_type == self.ONLINE_TYPE: | |||
| if self.ONLINE_SESSION_ID not in self.sessions: | |||
| context = DebuggerServerContext(dbg_mode='online') | |||
| online_session = DebuggerSession(context) | |||
| online_session.start() | |||
| self.sessions[self.ONLINE_SESSION_ID] = online_session | |||
| return self.ONLINE_SESSION_ID | |||
| logger.error('Online session is already exist, cannot create online session twice.') | |||
| raise DebuggerSessionAlreadyExistError("online session") | |||
| logger.error( | |||
| 'Online session is unavailable, set --enable-debugger as true/1 to enable debugger ' | |||
| 'when start Mindinsight server.') | |||
| raise DebuggerOnlineSessionUnavailable() | |||
| return self.ONLINE_SESSION_ID | |||
| if train_job in self.train_jobs: | |||
| return self.train_jobs.get(train_job) | |||
| self._check_session_num() | |||
| summary_base_dir = settings.SUMMARY_BASE_DIR | |||
| unquote_path = unquote(train_job, errors='strict') | |||
| whole_path = os.path.join(summary_base_dir, unquote_path) | |||
| normalized_path = validate_and_normalize_path(whole_path) | |||
| context = DebuggerServerContext(dbg_mode='offline', train_job=train_job, dbg_dir=normalized_path) | |||
| session = DebuggerSession(context) | |||
| session.start() | |||
| session_id = str(self._next_session_id) | |||
| self.sessions[session_id] = session | |||
| self.train_jobs[train_job] = session_id | |||
| self._next_session_id += 1 | |||
| return session_id | |||
| return self._create_offline_session(train_job) | |||
| def delete_session(self, session_id): | |||
| """Delete session by session id.""" | |||
| @@ -46,6 +46,10 @@ class WatchpointHandler(StreamHandlerBase): | |||
| # whether the watchpoint list has been changed since last step | |||
| self._outdated = False | |||
| def set_outdated(self): | |||
| """Set outdated as True.""" | |||
| self._outdated = True | |||
| def put(self, value): | |||
| """ | |||
| Put Watchpoint into watchpoint handler. | |||
| @@ -306,7 +306,8 @@ class TrainingControlOperator: | |||
| self._cache_store.get_stream_handler(Streams.TENSOR).set_step(step_id) | |||
| self._cache_store.clean_data() | |||
| self._cache_store.clean_command() | |||
| metadata_stream.enable_recheck = False | |||
| metadata_stream.enable_recheck = True | |||
| metadata_stream.state = ServerStatus.WAITING.value | |||
| self._cache_store.get_stream_handler(Streams.WATCHPOINT).set_outdated() | |||
| log.debug("Send the Change_training_step CMD.") | |||
| return metadata_stream.get(['state', 'enable_recheck', 'step']) | |||
| @@ -73,4 +73,4 @@ class TestTrainingControlOperator: | |||
| with mock.patch.object(MetadataHandler, 'max_step_num', 10), \ | |||
| mock.patch.object(MetadataHandler, 'debugger_type', 'offline'): | |||
| res = self._server.control(mode=mode, params={'steps': 9}) | |||
| assert res == {'metadata': {'enable_recheck': False, 'state': state, 'step': 9}} | |||
| assert res == {'metadata': {'enable_recheck': True, 'state': state, 'step': 9}} | |||