|
- # Copyright 2020-2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Implement the session manager."""
- import os
- import threading
- from urllib.parse import unquote
-
- import _thread
-
- from mindinsight.conf import settings
- from mindinsight.debugger.common.log import LOGGER as logger
- from mindinsight.debugger.common.exceptions.exceptions import DebuggerSessionNumOverBoundError, \
- DebuggerSessionNotFoundError
- from mindinsight.debugger.debugger_services.debugger_server_factory import DebuggerServerContext
- from mindinsight.debugger.debugger_session import DebuggerSession
-
-
- class SessionManager:
- """The server manager of debugger."""
-
- ONLINE_TYPE = "ONLINE"
- MAX_SESSION_NUM = 2
- ONLINE_SESSION_ID = "0"
- _instance = None
- _cls_lock = threading.Lock()
-
- def __init__(self):
- self.train_jobs = {}
- self.sessions = {}
- self.session_id = 1
- self.online_session = None
- self._lock = threading.Lock()
- self._exiting = False
- enable_debugger = settings.ENABLE_DEBUGGER if hasattr(settings, 'ENABLE_DEBUGGER') else False
- if enable_debugger:
- self.creat_session(self.ONLINE_TYPE)
-
- @classmethod
- def get_instance(cls):
- """Get the singleton instance."""
- with cls._cls_lock:
- if cls._instance is None:
- cls._instance = SessionManager()
- return cls._instance
-
- def exit(self):
- """
- Called when the gunicorn worker process is exiting.
- """
- with self._lock:
- logger.info("Start to exit sessions.")
- self._exiting = True
- for session in self.sessions:
- session.stop()
- self.online_session.stop()
- logger.info("Exited.")
-
- def get_session(self, session_id):
- """
- Get session by session id or get all session info.
-
- Args:
- session_id (Union[None, str]: The id of session.
-
- Returns:
- DebuggerSession, debugger session object.
- """
- with self._lock:
- if session_id == self.ONLINE_SESSION_ID and self.online_session is not None:
- return self.online_session
-
- if session_id in self.sessions:
- return self.sessions.get(session_id)
-
- raise DebuggerSessionNotFoundError("{}".format(session_id))
-
- def creat_session(self, session_type, train_job=None):
- """
- Create session by the train job info.
-
- Args:
- session_type (str): The session_type.
- train_job (str): The train job info.
-
- Returns:
- str, session id.
- """
- with self._lock:
- if self._exiting:
- logger.info(
- "System is exiting, will terminate the thread.")
- _thread.exit()
-
- if session_type == self.ONLINE_TYPE:
- if self.online_session is None:
- context = DebuggerServerContext(dbg_mode='online')
- self.online_session = DebuggerSession(context)
- self.online_session.start()
- return self.ONLINE_SESSION_ID
-
- if train_job in self.train_jobs:
- return self.train_jobs.get(train_job)
-
- self._check_session_num()
- summary_base_dir = settings.SUMMARY_BASE_DIR
- unquote_path = unquote(train_job, errors='strict')
- whole_path = os.path.join(summary_base_dir, unquote_path)
- normalized_path = validate_and_normalize_path(whole_path)
- context = DebuggerServerContext(dbg_mode='offline', train_job=train_job, dbg_dir=normalized_path)
- session = DebuggerSession(context)
- session.start()
- session_id = str(self.session_id)
- self.sessions[session_id] = session
- self.train_jobs[train_job] = session_id
- self.session_id += 1
- return session_id
-
- def delete_session(self, session_id):
- """Delete session by session id."""
- with self._lock:
- if session_id == self.ONLINE_SESSION_ID:
- self.online_session.stop()
- self.online_session = None
- return
-
- if session_id not in self.sessions:
- raise DebuggerSessionNotFoundError("session id {}".format(session_id))
-
- session = self.sessions.get(session_id)
- session.stop()
- self.sessions.pop(session_id)
- self.train_jobs.pop(session.train_job)
- return
-
- def get_sessions(self):
- """get all sessions"""
- return {"train_jobs": self.train_jobs}
-
- def _check_session_num(self):
- """Check the amount of sessions."""
- if len(self.sessions) >= self.MAX_SESSION_NUM:
- raise DebuggerSessionNumOverBoundError()
-
-
- def validate_and_normalize_path(path):
- """Validate and normalize_path"""
- if not path:
- raise ValueError("The path is invalid!")
-
- path_str = str(path)
-
- if not path_str.startswith("/"):
- raise ValueError("The path is invalid!")
-
- try:
- normalized_path = os.path.realpath(path)
- except ValueError:
- raise ValueError("The path is invalid!")
-
- return normalized_path
|