|
|
|
@@ -22,7 +22,7 @@ import _thread |
|
|
|
from mindinsight.conf import settings |
|
|
|
from mindinsight.debugger.common.log import LOGGER as logger |
|
|
|
from mindinsight.debugger.common.exceptions.exceptions import DebuggerSessionNumOverBoundError, \ |
|
|
|
DebuggerSessionNotFoundError |
|
|
|
DebuggerSessionNotFoundError, DebuggerSessionAlreadyExistError |
|
|
|
from mindinsight.debugger.debugger_services.debugger_server_factory import DebuggerServerContext |
|
|
|
from mindinsight.debugger.debugger_session import DebuggerSession |
|
|
|
|
|
|
|
@@ -31,7 +31,7 @@ class SessionManager: |
|
|
|
"""The server manager of debugger.""" |
|
|
|
|
|
|
|
ONLINE_TYPE = "ONLINE" |
|
|
|
MAX_SESSION_NUM = 2 |
|
|
|
MAX_OFFLINE_SESSION_NUM = 2 |
|
|
|
ONLINE_SESSION_ID = "0" |
|
|
|
_instance = None |
|
|
|
_cls_lock = threading.Lock() |
|
|
|
@@ -39,8 +39,8 @@ class SessionManager: |
|
|
|
def __init__(self): |
|
|
|
self.train_jobs = {} |
|
|
|
self.sessions = {} |
|
|
|
self.session_id = 1 |
|
|
|
self.online_session = None |
|
|
|
# The offline session id is start from 1, and the online session id is 0. |
|
|
|
self._next_session_id = 1 |
|
|
|
self._lock = threading.Lock() |
|
|
|
self._exiting = False |
|
|
|
enable_debugger = settings.ENABLE_DEBUGGER if hasattr(settings, 'ENABLE_DEBUGGER') else False |
|
|
|
@@ -56,34 +56,29 @@ class SessionManager: |
|
|
|
return cls._instance |
|
|
|
|
|
|
|
def exit(self): |
|
|
|
""" |
|
|
|
Called when the gunicorn worker process is exiting. |
|
|
|
""" |
|
|
|
"""Called when the gunicorn worker process is exiting.""" |
|
|
|
with self._lock: |
|
|
|
logger.info("Start to exit sessions.") |
|
|
|
self._exiting = True |
|
|
|
for session in self.sessions: |
|
|
|
for session in self.sessions.values(): |
|
|
|
session.stop() |
|
|
|
self.online_session.stop() |
|
|
|
logger.info("Exited.") |
|
|
|
logger.info("Sessions exited.") |
|
|
|
|
|
|
|
def get_session(self, session_id): |
|
|
|
""" |
|
|
|
Get session by session id or get all session info. |
|
|
|
|
|
|
|
Args: |
|
|
|
session_id (Union[None, str]: The id of session. |
|
|
|
session_id (Union[None, str]): The id of session. |
|
|
|
|
|
|
|
Returns: |
|
|
|
DebuggerSession, debugger session object. |
|
|
|
""" |
|
|
|
with self._lock: |
|
|
|
if session_id == self.ONLINE_SESSION_ID and self.online_session is not None: |
|
|
|
return self.online_session |
|
|
|
|
|
|
|
if session_id in self.sessions: |
|
|
|
return self.sessions.get(session_id) |
|
|
|
|
|
|
|
logger.error('Debugger session %s is not found.', session_id) |
|
|
|
raise DebuggerSessionNotFoundError("{}".format(session_id)) |
|
|
|
|
|
|
|
def creat_session(self, session_type, train_job=None): |
|
|
|
@@ -104,11 +99,14 @@ class SessionManager: |
|
|
|
_thread.exit() |
|
|
|
|
|
|
|
if session_type == self.ONLINE_TYPE: |
|
|
|
if self.online_session is None: |
|
|
|
if self.ONLINE_SESSION_ID not in self.sessions: |
|
|
|
context = DebuggerServerContext(dbg_mode='online') |
|
|
|
self.online_session = DebuggerSession(context) |
|
|
|
self.online_session.start() |
|
|
|
return self.ONLINE_SESSION_ID |
|
|
|
online_session = DebuggerSession(context) |
|
|
|
online_session.start() |
|
|
|
self.sessions[self.ONLINE_SESSION_ID] = online_session |
|
|
|
return self.ONLINE_SESSION_ID |
|
|
|
logger.error('Online session is already exist, cannot create online session twice.') |
|
|
|
raise DebuggerSessionAlreadyExistError("online session") |
|
|
|
|
|
|
|
if train_job in self.train_jobs: |
|
|
|
return self.train_jobs.get(train_job) |
|
|
|
@@ -121,21 +119,21 @@ class SessionManager: |
|
|
|
context = DebuggerServerContext(dbg_mode='offline', train_job=train_job, dbg_dir=normalized_path) |
|
|
|
session = DebuggerSession(context) |
|
|
|
session.start() |
|
|
|
session_id = str(self.session_id) |
|
|
|
session_id = str(self._next_session_id) |
|
|
|
self.sessions[session_id] = session |
|
|
|
self.train_jobs[train_job] = session_id |
|
|
|
self.session_id += 1 |
|
|
|
self._next_session_id += 1 |
|
|
|
return session_id |
|
|
|
|
|
|
|
def delete_session(self, session_id): |
|
|
|
"""Delete session by session id.""" |
|
|
|
with self._lock: |
|
|
|
if session_id == self.ONLINE_SESSION_ID: |
|
|
|
self.online_session.stop() |
|
|
|
self.online_session = None |
|
|
|
return |
|
|
|
logger.error('Online session can not be deleted.') |
|
|
|
raise ValueError("Online session can not be delete.") |
|
|
|
|
|
|
|
if session_id not in self.sessions: |
|
|
|
logger.error('Debugger session %s is not found', session_id) |
|
|
|
raise DebuggerSessionNotFoundError("session id {}".format(session_id)) |
|
|
|
|
|
|
|
session = self.sessions.get(session_id) |
|
|
|
@@ -144,29 +142,37 @@ class SessionManager: |
|
|
|
self.train_jobs.pop(session.train_job) |
|
|
|
return |
|
|
|
|
|
|
|
def get_sessions(self): |
|
|
|
"""get all sessions""" |
|
|
|
def get_train_jobs(self): |
|
|
|
"""Get all train jobs.""" |
|
|
|
return {"train_jobs": self.train_jobs} |
|
|
|
|
|
|
|
def _check_session_num(self): |
|
|
|
"""Check the amount of sessions.""" |
|
|
|
if len(self.sessions) >= self.MAX_SESSION_NUM: |
|
|
|
session_limitation = self.MAX_OFFLINE_SESSION_NUM |
|
|
|
if self.ONLINE_SESSION_ID in self.sessions: |
|
|
|
session_limitation += 1 |
|
|
|
if len(self.sessions) >= session_limitation: |
|
|
|
logger.warning('Offline debugger session num %s is reach the limitation %s', len(self.sessions), |
|
|
|
session_limitation) |
|
|
|
raise DebuggerSessionNumOverBoundError() |
|
|
|
|
|
|
|
|
|
|
|
def validate_and_normalize_path(path): |
|
|
|
"""Validate and normalize_path""" |
|
|
|
if not path: |
|
|
|
raise ValueError("The path is invalid!") |
|
|
|
logger.error('The whole path of dump directory is None.') |
|
|
|
raise ValueError("The whole path of dump directory is None.") |
|
|
|
|
|
|
|
path_str = str(path) |
|
|
|
|
|
|
|
if not path_str.startswith("/"): |
|
|
|
raise ValueError("The path is invalid!") |
|
|
|
logger.error('The whole path of dump directory is not start with \'/\'') |
|
|
|
raise ValueError("The whole path of dump directory is not start with \'/\'") |
|
|
|
|
|
|
|
try: |
|
|
|
normalized_path = os.path.realpath(path) |
|
|
|
except ValueError: |
|
|
|
raise ValueError("The path is invalid!") |
|
|
|
logger.error('The whole path of dump directory is invalid.') |
|
|
|
raise ValueError("The whole path of dump directory is invalid.") |
|
|
|
|
|
|
|
return normalized_path |