|
|
|
@@ -26,7 +26,7 @@ from mindinsight.datavisual.utils.tools import to_float |
|
|
|
from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \ |
|
|
|
DebuggerParamTypeError, DebuggerCreateWatchPointError, DebuggerUpdateWatchPointError, \ |
|
|
|
DebuggerDeleteWatchPointError, DebuggerContinueError, DebuggerPauseError, \ |
|
|
|
DebuggerCompareTensorError, DebuggerRecheckError |
|
|
|
DebuggerCompareTensorError, DebuggerRecheckError, DebuggerStepNumError |
|
|
|
from mindinsight.debugger.common.log import LOGGER as log |
|
|
|
from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \ |
|
|
|
create_view_event_from_tensor_history, Streams, is_scope_type, RunLevel |
|
|
|
@@ -42,6 +42,8 @@ from mindinsight.utils.tensor import TensorUtils, MAX_DIMENSIONS_FOR_TENSOR |
|
|
|
|
|
|
|
class DebuggerServer: |
|
|
|
"""The server manager of debugger.""" |
|
|
|
# max step number should be less than int32 |
|
|
|
_MAX_STEP_NUM = 2 ** 31 - 1 |
|
|
|
|
|
|
|
def __init__(self, grpc_port=None): |
|
|
|
self.grpc_port = grpc_port |
|
|
|
@@ -677,12 +679,14 @@ class DebuggerServer: |
|
|
|
dict, metadata info. |
|
|
|
""" |
|
|
|
if metadata_stream.state != ServerStatus.WAITING.value: |
|
|
|
self.cache_store.put_data(metadata_stream.get()) |
|
|
|
log.error("MindSpore is not ready to run. Current state is: %s", metadata_stream.state) |
|
|
|
raise DebuggerContinueError( |
|
|
|
"MindSpore is not ready to run or is running currently." |
|
|
|
) |
|
|
|
metadata_stream.state = ServerStatus.RUNNING.value |
|
|
|
try: |
|
|
|
self._validate_continue_params(params) |
|
|
|
event = self._construct_run_event(params) |
|
|
|
self._send_watchpoints() |
|
|
|
self.cache_store.put_command(event) |
|
|
|
@@ -696,6 +700,41 @@ class DebuggerServer: |
|
|
|
log.debug("Send the RunCMD to command queue.") |
|
|
|
return metadata_stream.get(['state', 'enable_recheck']) |
|
|
|
|
|
|
|
def _validate_continue_params(self, params): |
|
|
|
""" |
|
|
|
Validate continue params. |
|
|
|
|
|
|
|
Args: |
|
|
|
params (dict): The control params. |
|
|
|
|
|
|
|
- level (str): The control granularity, `node`, `step` or `recheck` level. |
|
|
|
Default: `step`. |
|
|
|
- steps (int): Specify the steps that training should run. |
|
|
|
Used when `level` is `step`. |
|
|
|
- name (str): Specify the name of the node. Used when `level` is `node`. |
|
|
|
- graph_name (str): The graph name. |
|
|
|
|
|
|
|
Raises: |
|
|
|
DebuggerParamValueError: Params are invalid. |
|
|
|
""" |
|
|
|
# validate level |
|
|
|
level = params.get('level', 'step') |
|
|
|
if level not in [RunLevel.NODE.value, RunLevel.STEP.value, RunLevel.RECHECK.value]: |
|
|
|
log.error("Invalid Value. `level` should be `step`, `node` or `recheck`. Got %s", level) |
|
|
|
raise DebuggerParamValueError("level` should be `step`, `node` or `recheck`.") |
|
|
|
|
|
|
|
# validate steps |
|
|
|
step_num = params.get('steps', 1) |
|
|
|
if not isinstance(step_num, int) or not (step_num == -1 or 0 < step_num <= self._MAX_STEP_NUM): |
|
|
|
log.error("Invalid step value. Step number should be integer and in [1, 2^31 - 1] or -1.") |
|
|
|
raise DebuggerStepNumError |
|
|
|
|
|
|
|
# validate node name |
|
|
|
if level == RunLevel.NODE.value: |
|
|
|
node_name = params.get('name') |
|
|
|
graph_name = params.get('graph_name') |
|
|
|
self._validate_continue_node_name(node_name, graph_name) |
|
|
|
|
|
|
|
def _construct_run_event(self, params): |
|
|
|
""" |
|
|
|
Construct run cmd from input control params. |
|
|
|
@@ -714,22 +753,15 @@ class DebuggerServer: |
|
|
|
EventReply, control event with run command. |
|
|
|
""" |
|
|
|
level = params.get('level', 'step') |
|
|
|
# validate level |
|
|
|
if level not in [RunLevel.NODE.value, RunLevel.STEP.value, RunLevel.RECHECK.value]: |
|
|
|
log.error("Invalid Value. `level` should be `step`, `node` or `recheck`. Got %s", level) |
|
|
|
raise DebuggerParamValueError("level` should be `step`, `node` or `recheck`.") |
|
|
|
# construct run command events |
|
|
|
event = get_ack_reply() |
|
|
|
if level == 'step': |
|
|
|
steps = params.get('steps') |
|
|
|
if not steps: |
|
|
|
steps = 1 |
|
|
|
steps = params.get('steps', 1) |
|
|
|
run_cmd = RunCMD(run_level='step', run_steps=steps) |
|
|
|
elif level == 'node': |
|
|
|
name = params.get('name', '') |
|
|
|
graph_name = params.get('graph_name') |
|
|
|
if name: |
|
|
|
self._validate_leaf_name(name, graph_name) |
|
|
|
name = self.cache_store.get_stream_handler(Streams.GRAPH).get_full_name(name, graph_name) |
|
|
|
run_cmd = RunCMD(run_level='node', node_name=name) |
|
|
|
else: |
|
|
|
@@ -739,8 +771,10 @@ class DebuggerServer: |
|
|
|
log.debug("Construct run event. %s", event) |
|
|
|
return event |
|
|
|
|
|
|
|
def _validate_leaf_name(self, node_name, graph_name): |
|
|
|
def _validate_continue_node_name(self, node_name, graph_name): |
|
|
|
"""Validate if the node is a leaf node.""" |
|
|
|
if not node_name: |
|
|
|
return |
|
|
|
graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) |
|
|
|
node_type = graph_stream.get_node_type(node_name, graph_name) |
|
|
|
if is_scope_type(node_type): |
|
|
|
@@ -770,6 +804,7 @@ class DebuggerServer: |
|
|
|
dict, metadata info. |
|
|
|
""" |
|
|
|
if metadata_stream.state != ServerStatus.RUNNING.value: |
|
|
|
self.cache_store.put_data(metadata_stream.get()) |
|
|
|
log.error("The MindSpore is not running.") |
|
|
|
raise DebuggerPauseError("The MindSpore is not running.") |
|
|
|
metadata_stream.state = 'waiting' |
|
|
|
|