Browse Source

Fix the sessions API and mark watchpoints as outdated when jumping to another step.

pull/1308/head
jiangshuqiang 4 years ago
parent
commit
483803bc03
7 changed files with 56 additions and 38 deletions
  1. +5
    -5
      mindinsight/backend/debugger/debugger_api.py
  2. +2
    -2
      mindinsight/debugger/common/exceptions/error_code.py
  3. +5
    -5
      mindinsight/debugger/common/exceptions/exceptions.py
  4. +37
    -24
      mindinsight/debugger/session_manager.py
  5. +4
    -0
      mindinsight/debugger/stream_handler/watchpoint_handler.py
  6. +2
    -1
      mindinsight/debugger/stream_operator/training_control_operator.py
  7. +1
    -1
      tests/ut/debugger/stream_operator/test_training_control_operator.py

+ 5
- 5
mindinsight/backend/debugger/debugger_api.py View File

@@ -375,7 +375,7 @@ def set_recommended_watch_points(session_id):


@BLUEPRINT.route("/debugger/sessions", methods=["POST"])
def creat_session():
def create_session():
"""
Get session id if session exist, else create a session.

@@ -383,22 +383,22 @@ def creat_session():
str, session id.

Examples:
>>> POST http://xxxx/v1/mindinsight/debugger/get-session
>>> POST http://xxxx/v1/mindinsight/debugger/sessions
"""
body = _read_post_request(request)
summary_dir = body.get('dump_dir')
session_type = body.get('session_type')
reply = _wrap_reply(SessionManager.get_instance().creat_session, session_type, summary_dir)
reply = _wrap_reply(SessionManager.get_instance().create_session, session_type, summary_dir)
return reply


@BLUEPRINT.route("/debugger/sessions", methods=["GET"])
def get_sessions():
def get_train_jobs():
"""
Check the current active sessions.

Examples:
>>> POST http://xxxx/v1/mindinsight/debugger/check-sessions
>>> GET http://xxxx/v1/mindinsight/debugger/sessions
"""
reply = _wrap_reply(SessionManager.get_instance().get_train_jobs)
return reply


+ 2
- 2
mindinsight/debugger/common/exceptions/error_code.py View File

@@ -52,7 +52,7 @@ class DebuggerErrors(DebuggerErrorCodes):

DEBUGGER_SESSION_OVER_BOUND_ERROR = 0 | _DEBUGGER_SESSION_ERROR
DEBUGGER_SESSION_NOT_FOUND_ERROR = 1 | _DEBUGGER_SESSION_ERROR
DEBUGGER_SESSION_ALREADY_EXIST_ERROR = 2 | _DEBUGGER_SESSION_ERROR
DEBUGGER_ONLINE_SESSION_UNAVAILABLE = 2 | _DEBUGGER_SESSION_ERROR


@unique
@@ -80,4 +80,4 @@ class DebuggerErrorMsg(Enum):

DEBUGGER_SESSION_OVER_BOUND_ERROR = "The amount of sessions is over limitation."
DEBUGGER_SESSION_NOT_FOUND_ERROR = "Session {} not found."
DEBUGGER_SESSION_ALREADY_EXIST_ERROR = "Session {} already exist."
DEBUGGER_ONLINE_SESSION_UNAVAILABLE = "Online session is unavailable."

+ 5
- 5
mindinsight/debugger/common/exceptions/exceptions.py View File

@@ -247,12 +247,12 @@ class DebuggerSessionNotFoundError(MindInsightException):
)


class DebuggerSessionAlreadyExistError(MindInsightException):
class DebuggerOnlineSessionUnavailable(MindInsightException):
"""The error raised when the online debugger session is unavailable."""

def __init__(self, msg):
super(DebuggerSessionAlreadyExistError, self).__init__(
error=DebuggerErrors.DEBUGGER_SESSION_ALREADY_EXIST_ERROR,
message=DebuggerErrorMsg.DEBUGGER_SESSION_ALREADY_EXIST_ERROR.value.format(msg),
def __init__(self):
super(DebuggerOnlineSessionUnavailable, self).__init__(
error=DebuggerErrors.DEBUGGER_ONLINE_SESSION_UNAVAILABLE,
message=DebuggerErrorMsg.DEBUGGER_ONLINE_SESSION_UNAVAILABLE.value,
http_code=400
)

+ 37
- 24
mindinsight/debugger/session_manager.py View File

@@ -22,7 +22,7 @@ import _thread
from mindinsight.conf import settings
from mindinsight.debugger.common.log import LOGGER as logger
from mindinsight.debugger.common.exceptions.exceptions import DebuggerSessionNumOverBoundError, \
DebuggerSessionNotFoundError, DebuggerSessionAlreadyExistError
DebuggerSessionNotFoundError, DebuggerOnlineSessionUnavailable
from mindinsight.debugger.debugger_services.debugger_server_factory import DebuggerServerContext
from mindinsight.debugger.debugger_session import DebuggerSession

@@ -45,7 +45,7 @@ class SessionManager:
self._exiting = False
enable_debugger = settings.ENABLE_DEBUGGER if hasattr(settings, 'ENABLE_DEBUGGER') else False
if enable_debugger:
self.creat_session(self.ONLINE_TYPE)
self._create_online_session()

@classmethod
def get_instance(cls):
@@ -81,9 +81,36 @@ class SessionManager:
logger.error('Debugger session %s is not found.', session_id)
raise DebuggerSessionNotFoundError("{}".format(session_id))

def creat_session(self, session_type, train_job=None):
def _create_online_session(self):
    """
    Create the online debugger session and register it.

    A server context in 'online' mode is built, a session is started on it,
    and the session is stored in `self.sessions` under `self.ONLINE_SESSION_ID`.
    """
    # NOTE(review): the original indentation is not visible here; the whole
    # body is assumed to run while holding `self._lock` -- confirm.
    with self._lock:
        session = DebuggerSession(DebuggerServerContext(dbg_mode='online'))
        session.start()
        self.sessions[self.ONLINE_SESSION_ID] = session

def _create_offline_session(self, train_job):
    """
    Create an offline debugger session for a train job and register it.

    Args:
        train_job (str): The train job path. It is unquoted and joined onto
            `settings.SUMMARY_BASE_DIR` before being validated.

    Returns:
        str, the id assigned to the newly created session.

    Raises:
        ValueError: If `train_job` is not a string.
    """
    # Presumably raises when the session limit is reached -- verify against
    # `_check_session_num`.
    self._check_session_num()
    if not isinstance(train_job, str):
        logger.error('The train job path should be string.')
        raise ValueError("The train job path should be string.")

    # Build the dump directory from the base dir and the unquoted job path.
    dump_dir = validate_and_normalize_path(
        os.path.join(settings.SUMMARY_BASE_DIR, unquote(train_job, errors='strict')))

    offline_session = DebuggerSession(
        DebuggerServerContext(dbg_mode='offline', train_job=train_job, dbg_dir=dump_dir))
    offline_session.start()

    # Register the session under the next free id, then advance the counter.
    new_id = str(self._next_session_id)
    self.sessions[new_id] = offline_session
    self.train_jobs[train_job] = new_id
    self._next_session_id += 1
    return new_id

def create_session(self, session_type, train_job=None):
"""
Create session by the train job info.
Create the session by the train job info or session type if the session doesn't exist.

Args:
session_type (str): The session_type.
@@ -100,30 +127,16 @@ class SessionManager:

if session_type == self.ONLINE_TYPE:
if self.ONLINE_SESSION_ID not in self.sessions:
context = DebuggerServerContext(dbg_mode='online')
online_session = DebuggerSession(context)
online_session.start()
self.sessions[self.ONLINE_SESSION_ID] = online_session
return self.ONLINE_SESSION_ID
logger.error('Online session is already exist, cannot create online session twice.')
raise DebuggerSessionAlreadyExistError("online session")
logger.error(
'Online session is unavailable, set --enable-debugger as true/1 to enable debugger '
'when start Mindinsight server.')
raise DebuggerOnlineSessionUnavailable()
return self.ONLINE_SESSION_ID

if train_job in self.train_jobs:
return self.train_jobs.get(train_job)

self._check_session_num()
summary_base_dir = settings.SUMMARY_BASE_DIR
unquote_path = unquote(train_job, errors='strict')
whole_path = os.path.join(summary_base_dir, unquote_path)
normalized_path = validate_and_normalize_path(whole_path)
context = DebuggerServerContext(dbg_mode='offline', train_job=train_job, dbg_dir=normalized_path)
session = DebuggerSession(context)
session.start()
session_id = str(self._next_session_id)
self.sessions[session_id] = session
self.train_jobs[train_job] = session_id
self._next_session_id += 1
return session_id
return self._create_offline_session(train_job)

def delete_session(self, session_id):
"""Delete session by session id."""


+ 4
- 0
mindinsight/debugger/stream_handler/watchpoint_handler.py View File

@@ -46,6 +46,10 @@ class WatchpointHandler(StreamHandlerBase):
# whether the watchpoint list has been changed since last step
self._outdated = False

def set_outdated(self):
    """Set the outdated flag to True, marking the watchpoint list as changed."""
    self._outdated = True

def put(self, value):
"""
Put Watchpoint into watchpoint handler.


+ 2
- 1
mindinsight/debugger/stream_operator/training_control_operator.py View File

@@ -306,7 +306,8 @@ class TrainingControlOperator:
self._cache_store.get_stream_handler(Streams.TENSOR).set_step(step_id)
self._cache_store.clean_data()
self._cache_store.clean_command()
metadata_stream.enable_recheck = False
metadata_stream.enable_recheck = True
metadata_stream.state = ServerStatus.WAITING.value
self._cache_store.get_stream_handler(Streams.WATCHPOINT).set_outdated()
log.debug("Send the Change_training_step CMD.")
return metadata_stream.get(['state', 'enable_recheck', 'step'])

+ 1
- 1
tests/ut/debugger/stream_operator/test_training_control_operator.py View File

@@ -73,4 +73,4 @@ class TestTrainingControlOperator:
with mock.patch.object(MetadataHandler, 'max_step_num', 10), \
mock.patch.object(MetadataHandler, 'debugger_type', 'offline'):
res = self._server.control(mode=mode, params={'steps': 9})
assert res == {'metadata': {'enable_recheck': False, 'state': state, 'step': 9}}
assert res == {'metadata': {'enable_recheck': True, 'state': state, 'step': 9}}

Loading…
Cancel
Save