Browse Source

add step validation

tags/v1.1.0
yelihua 5 years ago
parent
commit
f0e733039b
4 changed files with 59 additions and 12 deletions
  1. +2
    -0
      mindinsight/debugger/common/exceptions/error_code.py
  2. +10
    -0
      mindinsight/debugger/common/exceptions/exceptions.py
  3. +45
    -10
      mindinsight/debugger/debugger_server.py
  4. +2
    -2
      tests/ut/debugger/test_debugger_server.py

+ 2
- 0
mindinsight/debugger/common/exceptions/error_code.py View File

@@ -28,6 +28,8 @@ class DebuggerErrors(DebuggerErrorCodes):
PARAM_TYPE_ERROR = 0 | _PARAM_ERROR_MASK
PARAM_VALUE_ERROR = 1 | _PARAM_ERROR_MASK

STEP_NUM_ERROR = 2 | _PARAM_ERROR_MASK

NODE_NOT_IN_GRAPH_ERROR = 0 | _DEBUGGER_GRAPH_ERROR
GRAPH_NOT_EXIST_ERROR = 1 | _DEBUGGER_GRAPH_ERROR



+ 10
- 0
mindinsight/debugger/common/exceptions/exceptions.py View File

@@ -136,3 +136,13 @@ class DebuggerGraphNotExistError(MindInsightException):
message=DebuggerErrorMsg.GRAPH_NOT_EXIST_ERROR.value,
http_code=400
)


class DebuggerStepNumError(MindInsightException):
    """Raised when the requested step number is not a valid int32 step value.

    The original docstring ("The graph does not exist.") was copy-pasted from
    DebuggerGraphNotExistError and did not describe this exception.
    """

    def __init__(self):
        super(DebuggerStepNumError, self).__init__(
            error=DebuggerErrors.STEP_NUM_ERROR,
            # Runtime message kept byte-identical to the original.
            message="The type of step number should be int32.",
            # 400: the client supplied an invalid step value.
            http_code=400
        )

+ 45
- 10
mindinsight/debugger/debugger_server.py View File

@@ -26,7 +26,7 @@ from mindinsight.datavisual.utils.tools import to_float
from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \
DebuggerParamTypeError, DebuggerCreateWatchPointError, DebuggerUpdateWatchPointError, \
DebuggerDeleteWatchPointError, DebuggerContinueError, DebuggerPauseError, \
DebuggerCompareTensorError, DebuggerRecheckError
DebuggerCompareTensorError, DebuggerRecheckError, DebuggerStepNumError
from mindinsight.debugger.common.log import LOGGER as log
from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \
create_view_event_from_tensor_history, Streams, is_scope_type, RunLevel
@@ -42,6 +42,8 @@ from mindinsight.utils.tensor import TensorUtils, MAX_DIMENSIONS_FOR_TENSOR

class DebuggerServer:
"""The server manager of debugger."""
# max step number should be less than int32
_MAX_STEP_NUM = 2 ** 31 - 1

def __init__(self, grpc_port=None):
self.grpc_port = grpc_port
@@ -677,12 +679,14 @@ class DebuggerServer:
dict, metadata info.
"""
if metadata_stream.state != ServerStatus.WAITING.value:
self.cache_store.put_data(metadata_stream.get())
log.error("MindSpore is not ready to run. Current state is: %s", metadata_stream.state)
raise DebuggerContinueError(
"MindSpore is not ready to run or is running currently."
)
metadata_stream.state = ServerStatus.RUNNING.value
try:
self._validate_continue_params(params)
event = self._construct_run_event(params)
self._send_watchpoints()
self.cache_store.put_command(event)
@@ -696,6 +700,41 @@ class DebuggerServer:
log.debug("Send the RunCMD to command queue.")
return metadata_stream.get(['state', 'enable_recheck'])

def _validate_continue_params(self, params):
    """
    Validate continue params.

    Args:
        params (dict): The control params.

            - level (str): The control granularity, `node`, `step` or `recheck` level.
                Default: `step`.
            - steps (int): Specify the steps that training should run.
                Used when `level` is `step`. Default: 1.
            - name (str): Specify the name of the node. Used when `level` is `node`.
            - graph_name (str): The graph name.

    Raises:
        DebuggerParamValueError: If `level` is not `step`, `node` or `recheck`.
        DebuggerStepNumError: If `steps` is not an int equal to -1 or in
            [1, 2^31 - 1].
    """
    # validate level
    level = params.get('level', 'step')
    if level not in [RunLevel.NODE.value, RunLevel.STEP.value, RunLevel.RECHECK.value]:
        log.error("Invalid Value. `level` should be `step`, `node` or `recheck`. Got %s", level)
        # Fix: the original message was missing the opening backtick around `level`.
        raise DebuggerParamValueError("`level` should be `step`, `node` or `recheck`.")

    # validate steps; -1 means "run to the end", otherwise the value must fit in int32.
    # bool is a subclass of int in Python, so True/False would slip through a plain
    # isinstance(step_num, int) check — reject it explicitly.
    step_num = params.get('steps', 1)
    if isinstance(step_num, bool) or not isinstance(step_num, int) or \
            not (step_num == -1 or 0 < step_num <= self._MAX_STEP_NUM):
        log.error("Invalid step value. Step number should be integer and in [1, 2^31 - 1] or -1.")
        raise DebuggerStepNumError

    # validate node name only when control is at node granularity
    if level == RunLevel.NODE.value:
        node_name = params.get('name')
        graph_name = params.get('graph_name')
        self._validate_continue_node_name(node_name, graph_name)

def _construct_run_event(self, params):
"""
Construct run cmd from input control params.
@@ -714,22 +753,15 @@ class DebuggerServer:
EventReply, control event with run command.
"""
level = params.get('level', 'step')
# validate level
if level not in [RunLevel.NODE.value, RunLevel.STEP.value, RunLevel.RECHECK.value]:
log.error("Invalid Value. `level` should be `step`, `node` or `recheck`. Got %s", level)
raise DebuggerParamValueError("level` should be `step`, `node` or `recheck`.")
# construct run command events
event = get_ack_reply()
if level == 'step':
steps = params.get('steps')
if not steps:
steps = 1
steps = params.get('steps', 1)
run_cmd = RunCMD(run_level='step', run_steps=steps)
elif level == 'node':
name = params.get('name', '')
graph_name = params.get('graph_name')
if name:
self._validate_leaf_name(name, graph_name)
name = self.cache_store.get_stream_handler(Streams.GRAPH).get_full_name(name, graph_name)
run_cmd = RunCMD(run_level='node', node_name=name)
else:
@@ -739,8 +771,10 @@ class DebuggerServer:
log.debug("Construct run event. %s", event)
return event

def _validate_leaf_name(self, node_name, graph_name):
def _validate_continue_node_name(self, node_name, graph_name):
"""Validate if the node is a leaf node."""
if not node_name:
return
graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
node_type = graph_stream.get_node_type(node_name, graph_name)
if is_scope_type(node_type):
@@ -770,6 +804,7 @@ class DebuggerServer:
dict, metadata info.
"""
if metadata_stream.state != ServerStatus.RUNNING.value:
self.cache_store.put_data(metadata_stream.get())
log.error("The MindSpore is not running.")
raise DebuggerPauseError("The MindSpore is not running.")
metadata_stream.state = 'waiting'


+ 2
- 2
tests/ut/debugger/test_debugger_server.py View File

@@ -159,7 +159,7 @@ class TestDebuggerServer:
"""Test validate leaf name."""
args[0].return_value = 'name_scope'
with pytest.raises(DebuggerParamValueError, match='Invalid leaf node name.'):
self._server._validate_leaf_name(node_name='mock_node_name', graph_name='mock_graph_name')
self._server._validate_continue_node_name(node_name='mock_node_name', graph_name='mock_graph_name')

@mock.patch.object(TensorHandler, 'get')
@mock.patch.object(DebuggerServer, '_get_tensor_name_and_type_by_ui_name')
@@ -187,7 +187,7 @@ class TestDebuggerServer:
res = self._server._retrieve_watchpoint({'watch_point_id': 1})
assert res == mock_watchpoint

@mock.patch.object(DebuggerServer, '_validate_leaf_name')
@mock.patch.object(DebuggerServer, '_validate_continue_node_name')
@mock.patch.object(DebuggerServer, '_get_tensor_history')
@mock.patch.object(DebuggerServer, '_get_nodes_info', return_value={'graph': {}})
def test_retrieve_watchpoint_hit(self, *args):


Loading…
Cancel
Save