You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

session_manager.py 5.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Implement the session manager."""
  16. import os
  17. import threading
  18. from urllib.parse import unquote
  19. import _thread
  20. from mindinsight.conf import settings
  21. from mindinsight.debugger.common.log import LOGGER as logger
  22. from mindinsight.debugger.common.exceptions.exceptions import DebuggerSessionNumOverBoundError, \
  23. DebuggerSessionNotFoundError
  24. from mindinsight.debugger.debugger_services.debugger_server_factory import DebuggerServerContext
  25. from mindinsight.debugger.debugger_session import DebuggerSession
  26. class SessionManager:
  27. """The server manager of debugger."""
  28. ONLINE_TYPE = "ONLINE"
  29. MAX_SESSION_NUM = 2
  30. ONLINE_SESSION_ID = "0"
  31. _instance = None
  32. _cls_lock = threading.Lock()
  33. def __init__(self):
  34. self.train_jobs = {}
  35. self.sessions = {}
  36. self.session_id = 1
  37. self.online_session = None
  38. self._lock = threading.Lock()
  39. self._exiting = False
  40. enable_debugger = settings.ENABLE_DEBUGGER if hasattr(settings, 'ENABLE_DEBUGGER') else False
  41. if enable_debugger:
  42. self.creat_session(self.ONLINE_TYPE)
  43. @classmethod
  44. def get_instance(cls):
  45. """Get the singleton instance."""
  46. with cls._cls_lock:
  47. if cls._instance is None:
  48. cls._instance = SessionManager()
  49. return cls._instance
  50. def exit(self):
  51. """
  52. Called when the gunicorn worker process is exiting.
  53. """
  54. with self._lock:
  55. logger.info("Start to exit sessions.")
  56. self._exiting = True
  57. for session in self.sessions:
  58. session.stop()
  59. self.online_session.stop()
  60. logger.info("Exited.")
  61. def get_session(self, session_id):
  62. """
  63. Get session by session id or get all session info.
  64. Args:
  65. session_id (Union[None, str]: The id of session.
  66. Returns:
  67. DebuggerSession, debugger session object.
  68. """
  69. with self._lock:
  70. if session_id == self.ONLINE_SESSION_ID and self.online_session is not None:
  71. return self.online_session
  72. if session_id in self.sessions:
  73. return self.sessions.get(session_id)
  74. raise DebuggerSessionNotFoundError("{}".format(session_id))
  75. def creat_session(self, session_type, train_job=None):
  76. """
  77. Create session by the train job info.
  78. Args:
  79. session_type (str): The session_type.
  80. train_job (str): The train job info.
  81. Returns:
  82. str, session id.
  83. """
  84. with self._lock:
  85. if self._exiting:
  86. logger.info(
  87. "System is exiting, will terminate the thread.")
  88. _thread.exit()
  89. if session_type == self.ONLINE_TYPE:
  90. if self.online_session is None:
  91. context = DebuggerServerContext(dbg_mode='online')
  92. self.online_session = DebuggerSession(context)
  93. self.online_session.start()
  94. return self.ONLINE_SESSION_ID
  95. if train_job in self.train_jobs:
  96. return self.train_jobs.get(train_job)
  97. self._check_session_num()
  98. summary_base_dir = settings.SUMMARY_BASE_DIR
  99. unquote_path = unquote(train_job, errors='strict')
  100. whole_path = os.path.join(summary_base_dir, unquote_path)
  101. normalized_path = validate_and_normalize_path(whole_path)
  102. context = DebuggerServerContext(dbg_mode='offline', train_job=train_job, dbg_dir=normalized_path)
  103. session = DebuggerSession(context)
  104. session.start()
  105. session_id = str(self.session_id)
  106. self.sessions[session_id] = session
  107. self.train_jobs[train_job] = session_id
  108. self.session_id += 1
  109. return session_id
  110. def delete_session(self, session_id):
  111. """Delete session by session id."""
  112. with self._lock:
  113. if session_id == self.ONLINE_SESSION_ID:
  114. self.online_session.stop()
  115. self.online_session = None
  116. return
  117. if session_id not in self.sessions:
  118. raise DebuggerSessionNotFoundError("session id {}".format(session_id))
  119. session = self.sessions.get(session_id)
  120. session.stop()
  121. self.sessions.pop(session_id)
  122. self.train_jobs.pop(session.train_job)
  123. return
  124. def get_sessions(self):
  125. """get all sessions"""
  126. return {"train_jobs": self.train_jobs}
  127. def _check_session_num(self):
  128. """Check the amount of sessions."""
  129. if len(self.sessions) >= self.MAX_SESSION_NUM:
  130. raise DebuggerSessionNumOverBoundError()
  131. def validate_and_normalize_path(path):
  132. """Validate and normalize_path"""
  133. if not path:
  134. raise ValueError("The path is invalid!")
  135. path_str = str(path)
  136. if not path_str.startswith("/"):
  137. raise ValueError("The path is invalid!")
  138. try:
  139. normalized_path = os.path.realpath(path)
  140. except ValueError:
  141. raise ValueError("The path is invalid!")
  142. return normalized_path