diff --git a/mindinsight/backend/debugger/debugger_api.py b/mindinsight/backend/debugger/debugger_api.py index 97d7387b..2c900f2b 100644 --- a/mindinsight/backend/debugger/debugger_api.py +++ b/mindinsight/backend/debugger/debugger_api.py @@ -375,7 +375,7 @@ def set_recommended_watch_points(session_id): @BLUEPRINT.route("/debugger/sessions", methods=["POST"]) -def creat_session(): +def create_session(): """ Get session id if session exist, else create a session. @@ -383,22 +383,22 @@ def creat_session(): str, session id. Examples: - >>> POST http://xxxx/v1/mindinsight/debugger/get-session + >>> POST http://xxxx/v1/mindinsight/debugger/sessions """ body = _read_post_request(request) summary_dir = body.get('dump_dir') session_type = body.get('session_type') - reply = _wrap_reply(SessionManager.get_instance().creat_session, session_type, summary_dir) + reply = _wrap_reply(SessionManager.get_instance().create_session, session_type, summary_dir) return reply @BLUEPRINT.route("/debugger/sessions", methods=["GET"]) -def get_sessions(): +def get_train_jobs(): """ Check the current active sessions. 
Examples: - >>> POST http://xxxx/v1/mindinsight/debugger/check-sessions + >>> GET http://xxxx/v1/mindinsight/debugger/sessions """ reply = _wrap_reply(SessionManager.get_instance().get_train_jobs) return reply diff --git a/mindinsight/debugger/common/exceptions/error_code.py b/mindinsight/debugger/common/exceptions/error_code.py index ff426882..a0483978 100644 --- a/mindinsight/debugger/common/exceptions/error_code.py +++ b/mindinsight/debugger/common/exceptions/error_code.py @@ -52,7 +52,7 @@ class DebuggerErrors(DebuggerErrorCodes): DEBUGGER_SESSION_OVER_BOUND_ERROR = 0 | _DEBUGGER_SESSION_ERROR DEBUGGER_SESSION_NOT_FOUND_ERROR = 1 | _DEBUGGER_SESSION_ERROR - DEBUGGER_SESSION_ALREADY_EXIST_ERROR = 2 | _DEBUGGER_SESSION_ERROR + DEBUGGER_ONLINE_SESSION_UNAVAILABLE = 2 | _DEBUGGER_SESSION_ERROR @unique @@ -80,4 +80,4 @@ class DebuggerErrorMsg(Enum): DEBUGGER_SESSION_OVER_BOUND_ERROR = "The amount of sessions is over limitation." DEBUGGER_SESSION_NOT_FOUND_ERROR = "Session {} not found." - DEBUGGER_SESSION_ALREADY_EXIST_ERROR = "Session {} already exist." + DEBUGGER_ONLINE_SESSION_UNAVAILABLE = "Online session is unavailable." 
diff --git a/mindinsight/debugger/common/exceptions/exceptions.py b/mindinsight/debugger/common/exceptions/exceptions.py index 982eb08b..37e52734 100644 --- a/mindinsight/debugger/common/exceptions/exceptions.py +++ b/mindinsight/debugger/common/exceptions/exceptions.py @@ -247,12 +247,12 @@ class DebuggerSessionNotFoundError(MindInsightException): ) -class DebuggerSessionAlreadyExistError(MindInsightException): +class DebuggerOnlineSessionUnavailable(MindInsightException): - """The error of that the session already exist.""" + """The error that the online session is unavailable.""" - def __init__(self, msg): - super(DebuggerSessionAlreadyExistError, self).__init__( - error=DebuggerErrors.DEBUGGER_SESSION_ALREADY_EXIST_ERROR, - message=DebuggerErrorMsg.DEBUGGER_SESSION_ALREADY_EXIST_ERROR.value.format(msg), + def __init__(self): + super(DebuggerOnlineSessionUnavailable, self).__init__( + error=DebuggerErrors.DEBUGGER_ONLINE_SESSION_UNAVAILABLE, + message=DebuggerErrorMsg.DEBUGGER_ONLINE_SESSION_UNAVAILABLE.value, http_code=400 ) diff --git a/mindinsight/debugger/session_manager.py b/mindinsight/debugger/session_manager.py index a990f6b6..89cd428e 100644 --- a/mindinsight/debugger/session_manager.py +++ b/mindinsight/debugger/session_manager.py @@ -22,7 +22,7 @@ import _thread from mindinsight.conf import settings from mindinsight.debugger.common.log import LOGGER as logger from mindinsight.debugger.common.exceptions.exceptions import DebuggerSessionNumOverBoundError, \ - DebuggerSessionNotFoundError, DebuggerSessionAlreadyExistError + DebuggerSessionNotFoundError, DebuggerOnlineSessionUnavailable from mindinsight.debugger.debugger_services.debugger_server_factory import DebuggerServerContext from mindinsight.debugger.debugger_session import DebuggerSession @@ -45,7 +45,7 @@ class SessionManager: self._exiting = False enable_debugger = settings.ENABLE_DEBUGGER if hasattr(settings, 'ENABLE_DEBUGGER') else False if enable_debugger: - self.creat_session(self.ONLINE_TYPE) + self._create_online_session() @classmethod def 
get_instance(cls): @@ -81,9 +81,36 @@ class SessionManager: logger.error('Debugger session %s is not found.', session_id) raise DebuggerSessionNotFoundError("{}".format(session_id)) - def creat_session(self, session_type, train_job=None): + def _create_online_session(self): + """Create online session.""" + with self._lock: + context = DebuggerServerContext(dbg_mode='online') + online_session = DebuggerSession(context) + online_session.start() + self.sessions[self.ONLINE_SESSION_ID] = online_session + + def _create_offline_session(self, train_job): + """Create offline session.""" + self._check_session_num() + if not isinstance(train_job, str): + logger.error('The train job path should be string.') + raise ValueError("The train job path should be string.") + summary_base_dir = settings.SUMMARY_BASE_DIR + unquote_path = unquote(train_job, errors='strict') + whole_path = os.path.join(summary_base_dir, unquote_path) + normalized_path = validate_and_normalize_path(whole_path) + context = DebuggerServerContext(dbg_mode='offline', train_job=train_job, dbg_dir=normalized_path) + session = DebuggerSession(context) + session.start() + session_id = str(self._next_session_id) + self.sessions[session_id] = session + self.train_jobs[train_job] = session_id + self._next_session_id += 1 + return session_id + + def create_session(self, session_type, train_job=None): """ - Create session by the train job info. + Create the session by the train job info or session type if the session doesn't exist. Args: session_type (str): The session_type. 
@@ -100,30 +127,16 @@ class SessionManager: if session_type == self.ONLINE_TYPE: if self.ONLINE_SESSION_ID not in self.sessions: - context = DebuggerServerContext(dbg_mode='online') - online_session = DebuggerSession(context) - online_session.start() - self.sessions[self.ONLINE_SESSION_ID] = online_session - return self.ONLINE_SESSION_ID - logger.error('Online session is already exist, cannot create online session twice.') - raise DebuggerSessionAlreadyExistError("online session") + logger.error( + 'Online session is unavailable, set --enable-debugger as true/1 to enable debugger ' + 'when start Mindinsight server.') + raise DebuggerOnlineSessionUnavailable() + return self.ONLINE_SESSION_ID if train_job in self.train_jobs: return self.train_jobs.get(train_job) - self._check_session_num() - summary_base_dir = settings.SUMMARY_BASE_DIR - unquote_path = unquote(train_job, errors='strict') - whole_path = os.path.join(summary_base_dir, unquote_path) - normalized_path = validate_and_normalize_path(whole_path) - context = DebuggerServerContext(dbg_mode='offline', train_job=train_job, dbg_dir=normalized_path) - session = DebuggerSession(context) - session.start() - session_id = str(self._next_session_id) - self.sessions[session_id] = session - self.train_jobs[train_job] = session_id - self._next_session_id += 1 - return session_id + return self._create_offline_session(train_job) def delete_session(self, session_id): """Delete session by session id.""" diff --git a/mindinsight/debugger/stream_handler/watchpoint_handler.py b/mindinsight/debugger/stream_handler/watchpoint_handler.py index ed76891f..a46045fc 100644 --- a/mindinsight/debugger/stream_handler/watchpoint_handler.py +++ b/mindinsight/debugger/stream_handler/watchpoint_handler.py @@ -46,6 +46,10 @@ class WatchpointHandler(StreamHandlerBase): # whether the watchpoint list has been changed since last step self._outdated = False + def set_outdated(self): + """Set outdated as True.""" + self._outdated = True + def 
put(self, value): """ Put Watchpoint into watchpoint handler. diff --git a/mindinsight/debugger/stream_operator/training_control_operator.py b/mindinsight/debugger/stream_operator/training_control_operator.py index 5fa49924..3bb2ac85 100644 --- a/mindinsight/debugger/stream_operator/training_control_operator.py +++ b/mindinsight/debugger/stream_operator/training_control_operator.py @@ -306,7 +306,8 @@ class TrainingControlOperator: self._cache_store.get_stream_handler(Streams.TENSOR).set_step(step_id) self._cache_store.clean_data() self._cache_store.clean_command() - metadata_stream.enable_recheck = False + metadata_stream.enable_recheck = True metadata_stream.state = ServerStatus.WAITING.value + self._cache_store.get_stream_handler(Streams.WATCHPOINT).set_outdated() log.debug("Send the Change_training_step CMD.") return metadata_stream.get(['state', 'enable_recheck', 'step']) diff --git a/tests/ut/debugger/stream_operator/test_training_control_operator.py b/tests/ut/debugger/stream_operator/test_training_control_operator.py index 27553497..edf0982f 100644 --- a/tests/ut/debugger/stream_operator/test_training_control_operator.py +++ b/tests/ut/debugger/stream_operator/test_training_control_operator.py @@ -73,4 +73,4 @@ class TestTrainingControlOperator: with mock.patch.object(MetadataHandler, 'max_step_num', 10), \ mock.patch.object(MetadataHandler, 'debugger_type', 'offline'): res = self._server.control(mode=mode, params={'steps': 9}) - assert res == {'metadata': {'enable_recheck': False, 'state': state, 'step': 9}} + assert res == {'metadata': {'enable_recheck': True, 'state': state, 'step': 9}}