|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191 |
- # Copyright 2020-2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Implement the session manager."""
- import os
- import threading
- from urllib.parse import unquote
-
- import _thread
-
- from mindinsight.conf import settings
- from mindinsight.debugger.common.log import LOGGER as logger
- from mindinsight.debugger.common.exceptions.exceptions import DebuggerSessionNumOverBoundError, \
- DebuggerSessionNotFoundError, DebuggerOnlineSessionUnavailable
- from mindinsight.debugger.debugger_services.debugger_server_factory import DebuggerServerContext
- from mindinsight.debugger.debugger_session import DebuggerSession
-
-
class SessionManager:
    """
    Manage the debugger sessions of one server process.

    Sessions are kept in ``sessions`` keyed by their string id. The online session, when it
    exists, is stored under ``ONLINE_SESSION_ID``; offline sessions receive increasing ids
    starting from 1 and are additionally indexed by train job in ``train_jobs``.
    Access to both dicts is serialized with ``self._lock``.
    """

    ONLINE_TYPE = "ONLINE"
    MAX_OFFLINE_SESSION_NUM = 2
    ONLINE_SESSION_ID = "0"
    _instance = None
    _cls_lock = threading.Lock()

    def __init__(self):
        # Maps train job (exactly as passed to create_session) -> offline session id.
        self.train_jobs = {}
        # Maps session id -> session object.
        self.sessions = {}
        # The offline session id is start from 1, and the online session id is 0.
        self._next_session_id = 1
        self._lock = threading.Lock()
        self._exiting = False
        # The online session is only created when the setting exists and is truthy.
        if getattr(settings, 'ENABLE_DEBUGGER', False):
            self._create_online_session()

    @classmethod
    def get_instance(cls):
        """
        Get the singleton instance, creating it on first use.

        Returns:
            SessionManager, the shared manager instance.
        """
        with cls._cls_lock:
            if cls._instance is None:
                cls._instance = SessionManager()
            return cls._instance

    def exit(self):
        """Called when the gunicorn worker process is exiting."""
        with self._lock:
            logger.info("Start to exit sessions.")
            # Once set, create_session terminates the calling thread instead of creating sessions.
            self._exiting = True
            for session in self.sessions.values():
                session.stop()
            logger.info("Sessions exited.")

    def get_session(self, session_id):
        """
        Get session by session id.

        Args:
            session_id (str): The id of session.

        Returns:
            DebuggerSession, debugger session object.

        Raises:
            DebuggerSessionNotFoundError: If no session is registered under the id.
        """
        with self._lock:
            if session_id in self.sessions:
                return self.sessions[session_id]

            logger.error('Debugger session %s is not found.', session_id)
            raise DebuggerSessionNotFoundError("{}".format(session_id))

    def _create_online_session(self):
        """Create and start the online session and register it under ONLINE_SESSION_ID."""
        with self._lock:
            context = DebuggerServerContext(dbg_mode='online')
            online_session = DebuggerSession(context)
            online_session.start()
            self.sessions[self.ONLINE_SESSION_ID] = online_session

    def _create_offline_session(self, train_job):
        """
        Create, start and register an offline session for the train job.

        This method does not take ``self._lock`` itself; within this class it is only
        called from ``create_session``, which already holds the lock.

        Args:
            train_job (str): The url-quoted train job path, joined onto SUMMARY_BASE_DIR.

        Returns:
            str, the id of the new session.

        Raises:
            DebuggerSessionNumOverBoundError: If the session limit is already reached.
            ValueError: If the train job is not a string or the resulting path is rejected.
        """
        self._check_session_num()
        if not isinstance(train_job, str):
            logger.error('The train job path should be string.')
            raise ValueError("The train job path should be string.")
        summary_base_dir = settings.SUMMARY_BASE_DIR
        unquote_path = unquote(train_job, errors='strict')
        # NOTE(review): os.path.join discards summary_base_dir when unquote_path is absolute, and
        # '..' components are resolved by realpath below; nothing here checks that the result stays
        # inside summary_base_dir. Confirm whether callers validate train_job beforehand.
        whole_path = os.path.join(summary_base_dir, unquote_path)
        normalized_path = validate_and_normalize_path(whole_path)
        context = DebuggerServerContext(dbg_mode='offline', train_job=train_job, dbg_dir=normalized_path)
        session = DebuggerSession(context)
        session.start()
        # Register only after the session started, so a failing start leaves no entry behind.
        session_id = str(self._next_session_id)
        self.sessions[session_id] = session
        self.train_jobs[train_job] = session_id
        self._next_session_id += 1
        return session_id

    def create_session(self, session_type, train_job=None):
        """
        Create the session by the train job info or session type if the session doesn't exist.

        Args:
            session_type (str): The session_type.
            train_job (str): The train job info.

        Returns:
            str, session id.

        Raises:
            DebuggerOnlineSessionUnavailable: If an online session is requested but none exists.
        """
        with self._lock:
            if self._exiting:
                logger.info(
                    "System is exiting, will terminate the thread.")
                _thread.exit()

            if session_type == self.ONLINE_TYPE:
                if self.ONLINE_SESSION_ID not in self.sessions:
                    logger.error(
                        'Online session is unavailable, set --enable-debugger as true/1 to enable debugger '
                        'when start Mindinsight server.')
                    raise DebuggerOnlineSessionUnavailable()
                return self.ONLINE_SESSION_ID

            # Reuse the session that already serves this train job.
            if train_job in self.train_jobs:
                return self.train_jobs[train_job]

            return self._create_offline_session(train_job)

    def delete_session(self, session_id):
        """
        Delete session by session id.

        Args:
            session_id (str): The id of the offline session to stop and remove.

        Raises:
            ValueError: If the id is the online session id.
            DebuggerSessionNotFoundError: If no session is registered under the id.
        """
        with self._lock:
            if session_id == self.ONLINE_SESSION_ID:
                logger.error('Online session can not be deleted.')
                raise ValueError("Online session can not be delete.")

            if session_id not in self.sessions:
                logger.error('Debugger session %s is not found', session_id)
                raise DebuggerSessionNotFoundError("session id {}".format(session_id))

            session = self.sessions[session_id]
            session.stop()
            del self.sessions[session_id]
            # Drop every train job that points at this id. Looking entries up by id cannot raise
            # KeyError, so the registry is never left half-updated after the session was stopped.
            stale_jobs = [job for job, mapped_id in self.train_jobs.items() if mapped_id == session_id]
            for job in stale_jobs:
                del self.train_jobs[job]

    def get_train_jobs(self):
        """
        Get all train jobs.

        Returns:
            dict, ``{"train_jobs": {train_job: session_id}}``. The inner dict is a copy taken
            under the lock, so later session changes do not affect it.
        """
        with self._lock:
            return {"train_jobs": dict(self.train_jobs)}

    def _check_session_num(self):
        """
        Check the amount of sessions.

        Raises:
            DebuggerSessionNumOverBoundError: If the number of registered sessions has reached
                MAX_OFFLINE_SESSION_NUM (plus one when the online session is registered).
        """
        session_limitation = self.MAX_OFFLINE_SESSION_NUM
        # The online session shares the same dict, so it raises the allowed total by one.
        if self.ONLINE_SESSION_ID in self.sessions:
            session_limitation += 1
        if len(self.sessions) >= session_limitation:
            logger.warning('Offline debugger session num %s is reach the limitation %s', len(self.sessions),
                           session_limitation)
            raise DebuggerSessionNumOverBoundError()
-
-
def validate_and_normalize_path(path):
    """
    Check the dump directory path and return its real path.

    Args:
        path (str): The whole path of the dump directory.

    Returns:
        str, the result of ``os.path.realpath`` for the path.

    Raises:
        ValueError: If the path is empty, its string form does not start with '/',
            or ``os.path.realpath`` rejects it with a ValueError.
    """
    # Guard 1: reject empty / falsy input.
    if not path:
        empty_msg = 'The whole path of dump directory is None.'
        logger.error(empty_msg)
        raise ValueError(empty_msg)

    # Guard 2: only paths whose string form is rooted at '/' are accepted.
    is_rooted = str(path).startswith("/")
    if not is_rooted:
        prefix_msg = 'The whole path of dump directory is not start with \'/\''
        logger.error(prefix_msg)
        raise ValueError(prefix_msg)

    try:
        return os.path.realpath(path)
    except ValueError:
        invalid_msg = 'The whole path of dump directory is invalid.'
        logger.error(invalid_msg)
        raise ValueError(invalid_msg)
|