You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

session_manager.py 7.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Implement the session manager."""
  16. import os
  17. import threading
  18. from urllib.parse import unquote
  19. import _thread
  20. from mindinsight.conf import settings
  21. from mindinsight.debugger.common.log import LOGGER as logger
  22. from mindinsight.debugger.common.exceptions.exceptions import DebuggerSessionNumOverBoundError, \
  23. DebuggerSessionNotFoundError, DebuggerOnlineSessionUnavailable
  24. from mindinsight.debugger.debugger_services.debugger_server_factory import DebuggerServerContext
  25. from mindinsight.debugger.debugger_session import DebuggerSession
  26. class SessionManager:
  27. """The server manager of debugger."""
  28. ONLINE_TYPE = "ONLINE"
  29. MAX_OFFLINE_SESSION_NUM = 2
  30. ONLINE_SESSION_ID = "0"
  31. _instance = None
  32. _cls_lock = threading.Lock()
  33. def __init__(self):
  34. self.train_jobs = {}
  35. self.sessions = {}
  36. # The offline session id is start from 1, and the online session id is 0.
  37. self._next_session_id = 1
  38. self._lock = threading.Lock()
  39. self._exiting = False
  40. enable_debugger = settings.ENABLE_DEBUGGER if hasattr(settings, 'ENABLE_DEBUGGER') else False
  41. if enable_debugger:
  42. self._create_online_session()
  43. @classmethod
  44. def get_instance(cls):
  45. """Get the singleton instance."""
  46. with cls._cls_lock:
  47. if cls._instance is None:
  48. cls._instance = SessionManager()
  49. return cls._instance
  50. def exit(self):
  51. """Called when the gunicorn worker process is exiting."""
  52. with self._lock:
  53. logger.info("Start to exit sessions.")
  54. self._exiting = True
  55. for session in self.sessions.values():
  56. session.stop()
  57. logger.info("Sessions exited.")
  58. def get_session(self, session_id):
  59. """
  60. Get session by session id or get all session info.
  61. Args:
  62. session_id (Union[None, str]): The id of session.
  63. Returns:
  64. DebuggerSession, debugger session object.
  65. """
  66. with self._lock:
  67. if session_id in self.sessions:
  68. return self.sessions.get(session_id)
  69. logger.error('Debugger session %s is not found.', session_id)
  70. raise DebuggerSessionNotFoundError("{}".format(session_id))
  71. def _create_online_session(self):
  72. """Create online session."""
  73. with self._lock:
  74. context = DebuggerServerContext(dbg_mode='online')
  75. online_session = DebuggerSession(context)
  76. online_session.start()
  77. self.sessions[self.ONLINE_SESSION_ID] = online_session
  78. def _create_offline_session(self, train_job):
  79. """Create offline session."""
  80. self._check_session_num()
  81. if not isinstance(train_job, str):
  82. logger.error('The train job path should be string.')
  83. raise ValueError("The train job path should be string.")
  84. summary_base_dir = settings.SUMMARY_BASE_DIR
  85. unquote_path = unquote(train_job, errors='strict')
  86. whole_path = os.path.join(summary_base_dir, unquote_path)
  87. normalized_path = validate_and_normalize_path(whole_path)
  88. context = DebuggerServerContext(dbg_mode='offline', train_job=train_job, dbg_dir=normalized_path)
  89. session = DebuggerSession(context)
  90. session.start()
  91. session_id = str(self._next_session_id)
  92. self.sessions[session_id] = session
  93. self.train_jobs[train_job] = session_id
  94. self._next_session_id += 1
  95. return session_id
  96. def create_session(self, session_type, train_job=None):
  97. """
  98. Create the session by the train job info or session type if the session doesn't exist.
  99. Args:
  100. session_type (str): The session_type.
  101. train_job (str): The train job info.
  102. Returns:
  103. str, session id.
  104. """
  105. with self._lock:
  106. if self._exiting:
  107. logger.info(
  108. "System is exiting, will terminate the thread.")
  109. _thread.exit()
  110. if session_type == self.ONLINE_TYPE:
  111. if self.ONLINE_SESSION_ID not in self.sessions:
  112. logger.error(
  113. 'Online session is unavailable, set --enable-debugger as true/1 to enable debugger '
  114. 'when start Mindinsight server.')
  115. raise DebuggerOnlineSessionUnavailable()
  116. return self.ONLINE_SESSION_ID
  117. if train_job in self.train_jobs:
  118. return self.train_jobs.get(train_job)
  119. return self._create_offline_session(train_job)
  120. def delete_session(self, session_id):
  121. """Delete session by session id."""
  122. with self._lock:
  123. if session_id == self.ONLINE_SESSION_ID:
  124. logger.error('Online session can not be deleted.')
  125. raise ValueError("Online session can not be delete.")
  126. if session_id not in self.sessions:
  127. logger.error('Debugger session %s is not found', session_id)
  128. raise DebuggerSessionNotFoundError("session id {}".format(session_id))
  129. session = self.sessions.get(session_id)
  130. session.stop()
  131. self.sessions.pop(session_id)
  132. self.train_jobs.pop(session.train_job)
  133. return
  134. def get_train_jobs(self):
  135. """Get all train jobs."""
  136. return {"train_jobs": self.train_jobs}
  137. def _check_session_num(self):
  138. """Check the amount of sessions."""
  139. session_limitation = self.MAX_OFFLINE_SESSION_NUM
  140. if self.ONLINE_SESSION_ID in self.sessions:
  141. session_limitation += 1
  142. if len(self.sessions) >= session_limitation:
  143. logger.warning('Offline debugger session num %s is reach the limitation %s', len(self.sessions),
  144. session_limitation)
  145. raise DebuggerSessionNumOverBoundError()
  146. def validate_and_normalize_path(path):
  147. """Validate and normalize_path"""
  148. if not path:
  149. logger.error('The whole path of dump directory is None.')
  150. raise ValueError("The whole path of dump directory is None.")
  151. path_str = str(path)
  152. if not path_str.startswith("/"):
  153. logger.error('The whole path of dump directory is not start with \'/\'')
  154. raise ValueError("The whole path of dump directory is not start with \'/\'")
  155. try:
  156. normalized_path = os.path.realpath(path)
  157. except ValueError:
  158. logger.error('The whole path of dump directory is invalid.')
  159. raise ValueError("The whole path of dump directory is invalid.")
  160. return normalized_path