
refactor control API && auto-update tensor values for tensor graphs

tags/v1.1.0
yelihua committed 5 years ago
commit e1e7580362
20 changed files with 599 additions and 384 deletions
  1. +4 -1 mindinsight/debugger/common/exceptions/error_code.py
  2. +22 -0 mindinsight/debugger/common/exceptions/exceptions.py
  3. +8 -8 mindinsight/debugger/common/utils.py
  4. +23 -10 mindinsight/debugger/debugger_grpc_server.py
  5. +24 -218 mindinsight/debugger/debugger_server.py
  6. +16 -19 mindinsight/debugger/stream_cache/tensor.py
  7. +1 -2 mindinsight/debugger/stream_handler/graph_handler.py
  8. +78 -69 mindinsight/debugger/stream_handler/tensor_handler.py
  9. +30 -4 mindinsight/debugger/stream_operator/tensor_detail_info.py
  10. +273 -0 mindinsight/debugger/stream_operator/training_control_operator.py
  11. +2 -0 mindinsight/utils/tensor.py
  12. +0 -21 tests/st/func/debugger/expect_results/restful_results/compare_tensors.json
  13. +29 -1 tests/st/func/debugger/expect_results/restful_results/retrieve_tensor_value.json
  14. +1 -1 tests/st/func/debugger/mock_ms_client.py
  15. +1 -0 tests/ut/debugger/stream_handler/test_metadata_handler.py
  16. +1 -1 tests/ut/debugger/stream_handler/test_tensor_handler.py
  17. +15 -0 tests/ut/debugger/stream_operator/__init__.py
  18. +66 -0 tests/ut/debugger/stream_operator/test_training_control_operator.py
  19. +5 -5 tests/ut/debugger/test_debugger_grpc_server.py
  20. +0 -24 tests/ut/debugger/test_debugger_server.py

+4 -1 mindinsight/debugger/common/exceptions/error_code.py

@@ -27,7 +27,6 @@ class DebuggerErrors(DebuggerErrorCodes):
"""Debugger error codes."""
PARAM_TYPE_ERROR = 0 | _PARAM_ERROR_MASK
PARAM_VALUE_ERROR = 1 | _PARAM_ERROR_MASK

STEP_NUM_ERROR = 2 | _PARAM_ERROR_MASK

NODE_NOT_IN_GRAPH_ERROR = 0 | _DEBUGGER_GRAPH_ERROR
@@ -40,6 +39,8 @@ class DebuggerErrors(DebuggerErrorCodes):
PAUSE_ERROR = 4 | _DEBUGGER_RUNNING_ERROR
COMPARE_TENSOR_ERROR = 5 | _DEBUGGER_RUNNING_ERROR
RECHECK_ERROR = 6 | _DEBUGGER_RUNNING_ERROR
TENSOR_GRAPH_ERROR = 7 | _DEBUGGER_RUNNING_ERROR
TENSOR_HIT_ERROR = 8 | _DEBUGGER_RUNNING_ERROR


@unique
@@ -56,3 +57,5 @@ class DebuggerErrorMsg(Enum):
CONTINUE_ERROR = "Continue debugging failed. {}"
PAUSE_ERROR = "Pause debugging failed. {}"
RECHECK_ERROR = "Recheck failed. {}"
TENSOR_GRAPH_ERROR = "Get tensor graphs failed."
TENSOR_HIT_ERROR = "Get tensor hits failed."

+22 -0 mindinsight/debugger/common/exceptions/exceptions.py

@@ -146,3 +146,25 @@ class DebuggerStepNumError(MindInsightException):
message="The type of step number should be int32.",
http_code=400
)


class DebuggerTensorGraphError(MindInsightException):
"""The error about comparing tensors."""

def __init__(self):
super(DebuggerTensorGraphError, self).__init__(
error=DebuggerErrors.TENSOR_GRAPH_ERROR,
message=DebuggerErrorMsg.TENSOR_GRAPH_ERROR.value,
http_code=400
)


class DebuggerTensorHitError(MindInsightException):
"""The error about comparing tensors."""

def __init__(self):
super(DebuggerTensorHitError, self).__init__(
error=DebuggerErrors.TENSOR_HIT_ERROR,
message=DebuggerErrorMsg.TENSOR_HIT_ERROR.value,
http_code=400
)

+8 -8 mindinsight/debugger/common/utils.py

@@ -115,29 +115,29 @@ def wrap_reply_response(error_code=None, error_message=None):
return reply


def create_view_event_from_tensor_history(tensor_history):
def create_view_event_from_tensor_basic_info(tensors_info):
"""
Create view event reply according to tensor names.

Args:
tensor_history (list[dict]): The list of tensor history. Each element has keys:
`name`, `node_type`.
tensors_info (list[TensorBasicInfo]): The list of TensorBasicInfo. Each element has keys:
`full_name`, `node_type`, `iter`.

Returns:
EventReply, the event reply with view cmd.
"""
view_event = get_ack_reply()
for tensor_info in tensor_history:
node_type = tensor_info.get('node_type')
for tensor_info in tensors_info:
node_type = tensor_info.node_type
if node_type == NodeTypeEnum.CONST.value:
continue
truncate_tag = tensor_info.get('node_type') == NodeTypeEnum.PARAMETER.value
tensor_name = tensor_info.get('full_name', '')
truncate_tag = node_type == NodeTypeEnum.PARAMETER.value
tensor_name = tensor_info.full_name
# create view command
ms_tensor = view_event.view_cmd.tensors.add()
ms_tensor.node_name, ms_tensor.slot = tensor_name.rsplit(':', 1)
ms_tensor.truncate = truncate_tag
ms_tensor.iter = 'prev' if tensor_info.get('iter') else ''
ms_tensor.iter = tensor_info.iter

return view_event
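A hedged usage sketch of the renamed helper, combining it with the TensorBasicInfo namedtuple introduced in tensor_handler.py below (the tensor name is borrowed from this commit's test fixtures, and 'Parameter' is assumed to be NodeTypeEnum.PARAMETER.value):

from mindinsight.debugger.common.utils import create_view_event_from_tensor_basic_info
from mindinsight.debugger.stream_handler.tensor_handler import TensorBasicInfo

# Ask the client for the current and previous step values of one parameter.
missing = [
    TensorBasicInfo(full_name='Default/TransData-op99:0', node_type='Parameter', iter=''),
    TensorBasicInfo(full_name='Default/TransData-op99:0', node_type='Parameter', iter='prev'),
]
view_event = create_view_event_from_tensor_basic_info(missing)
# Const tensors are skipped, Parameter tensors are truncated, and `iter` is
# now forwarded as-is instead of being derived from a dict key.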



+23 -10 mindinsight/debugger/debugger_grpc_server.py

@@ -159,15 +159,15 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):

def _send_received_tensor_tag(self):
"""Send received_finish_tag."""
node_name = self._received_view_cmd.get('node_name')
if not node_name or self._received_view_cmd.get('wait_for_tensor'):
node_info = self._received_view_cmd.get('node_info')
if not node_info or self._received_view_cmd.get('wait_for_tensor'):
return
metadata = self._cache_store.get_stream_handler(Streams.METADATA).get(['step', 'state'])
ret = {'receive_tensor': {'node_name': node_name}}
ret = {'receive_tensor': node_info.copy()}
ret.update(metadata)
self._cache_store.put_data(ret)
self._received_view_cmd.clear()
log.debug("Send receive tensor flag for %s", node_name)
log.debug("Send receive tensor flag for %s", node_info)

def _send_watchpoint_hit_flag(self):
"""Send Watchpoint hit flag."""
@@ -281,14 +281,26 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
return event

def _deal_with_view_cmd(self, event):
"""Deal with view cmd."""
view_cmd = event.get('view_cmd')
node_name = event.get('node_name')
log.debug("Receive view cmd for node: %s.", node_name)
if not (view_cmd and node_name):
"""
Deal with view cmd.

Args:
event (dict): View command params.

- view_cmd (EventReply): EventReply with view command.
- node_name (str): The center node name for view command.
- tensor_name (str): The center tensor name for view command.
- graph_name (str): The graph name of center node.

Returns:
EventReply, view command to be sent to client.
"""
view_cmd = event.pop('view_cmd', None)
log.debug("Receive view cmd for node: %s.", event)
if not (view_cmd and event):
log.debug("Invalid view command. Ignore it.")
return None
self._received_view_cmd['node_name'] = node_name
self._received_view_cmd['node_info'] = event
self._received_view_cmd['wait_for_tensor'] = True
return view_cmd

@@ -395,6 +407,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
if tensor.finished:
update_flag = tensor_stream.put({'step': step, 'tensor_protos': tensor_construct})
if self._received_view_cmd.get('wait_for_tensor') and update_flag:
# update_flag is used to avoid querying empty tensors again
self._received_view_cmd['wait_for_tensor'] = False
log.debug("Set wait for tensor flag to False.")
tensor_construct = []
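To make the new contract concrete: everything in the command dict except view_cmd is now kept as node_info and echoed back to the UI once the requested tensors arrive. A sketch of the shape, with mock names borrowed from the updated tests:

from mindinsight.debugger.common.utils import get_ack_reply

event = {
    'view_cmd': get_ack_reply(),     # EventReply carrying the ViewCMD
    'node_name': 'mock_node_name',   # or 'tensor_name' for tensor-graph queries
    'graph_name': 'mock_graph_name',
}
view_cmd = event.pop('view_cmd', None)
node_info = event  # later sent back as {'receive_tensor': node_info.copy()}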


+24 -218 mindinsight/debugger/debugger_server.py

@@ -16,35 +16,33 @@
import signal
from concurrent import futures
from threading import Thread

import grpc

from mindinsight.conditionmgr.conditionmgr import ConditionMgr
from mindinsight.conditionmgr.common.utils import NodeBasicInfo
from mindinsight.conditionmgr.condition import ConditionContext, ConditionIdEnum
from mindinsight.conditionmgr.conditionmgr import ConditionMgr
from mindinsight.conditionmgr.recommender import recommend_watchpoints
from mindinsight.conf import settings
from mindinsight.datavisual.data_transform.graph import NodeTypeEnum
from mindinsight.datavisual.utils.tools import to_float
from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \
DebuggerParamTypeError, DebuggerCreateWatchPointError, DebuggerUpdateWatchPointError, \
DebuggerDeleteWatchPointError, DebuggerContinueError, DebuggerPauseError, \
DebuggerCompareTensorError, DebuggerRecheckError, DebuggerStepNumError
DebuggerDeleteWatchPointError, DebuggerCompareTensorError, DebuggerTensorGraphError, \
DebuggerTensorHitError
from mindinsight.debugger.common.log import LOGGER as log
from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \
create_view_event_from_tensor_history, Streams, is_scope_type, RunLevel
from mindinsight.conditionmgr.common.utils import NodeBasicInfo
from mindinsight.debugger.common.utils import ServerStatus, \
create_view_event_from_tensor_basic_info, Streams
from mindinsight.debugger.debugger_cache import DebuggerCache
from mindinsight.debugger.debugger_grpc_server import DebuggerGrpcServer
from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base
from mindinsight.debugger.proto.debug_grpc_pb2 import RunCMD
from mindinsight.debugger.stream_operator.tensor_detail_info import TensorDetailInfo
from mindinsight.utils.exceptions import MindInsightException
from mindinsight.debugger.stream_operator.training_control_operator import TrainingControlOperator
from mindinsight.utils.tensor import TensorUtils, MAX_DIMENSIONS_FOR_TENSOR


class DebuggerServer:
"""The server manager of debugger."""
# max step number should be less than int32
_MAX_STEP_NUM = 2 ** 31 - 1

def __init__(self, grpc_port=None):
self.grpc_port = grpc_port
@@ -355,7 +353,7 @@ class DebuggerServer:
graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
tensor_history = graph_stream.get_tensor_history(node_name, graph_name)
# add tensor value for tensor history
self._add_tensor_value_for_tensor_history(tensor_history, node_name)
self._add_tensor_value_for_tensor_history(tensor_history, node_name, graph_name)
# add hit label for tensor history
watchpoint_hit_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
watchpoint_hit_stream.update_tensor_history(tensor_history)
@@ -364,13 +362,14 @@ class DebuggerServer:
tensor_history.update(metadata)
return tensor_history

def _add_tensor_value_for_tensor_history(self, tensor_history, node_name):
def _add_tensor_value_for_tensor_history(self, tensor_history, node_name, graph_name):
"""
Add tensor value for tensor_history and send ViewCMD if any tensor value is missing.

Args:
tensor_history (list[dict]): A list of tensor info, including name and type.
node_name (str): The UI node name.
graph_name (str): The graph name.

Returns:
dict, the tensor info.
@@ -378,8 +377,8 @@ class DebuggerServer:
tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR)
missed_tensors = tensor_stream.update_tensor_history(tensor_history)
if missed_tensors:
view_cmd = create_view_event_from_tensor_history(missed_tensors)
self.cache_store.put_command({'view_cmd': view_cmd, 'node_name': node_name})
view_cmd = create_view_event_from_tensor_basic_info(missed_tensors)
self.cache_store.put_command({'view_cmd': view_cmd, 'node_name': node_name, 'graph_name': graph_name})
log.debug("Send view cmd.")

def retrieve_tensor_value(self, name, detail, shape, graph_name=None, prev=False):
@@ -679,189 +678,10 @@ class DebuggerServer:
dict, the response.
"""
log.info("Receive control request: %s.", params)
mode = params.get('mode')
metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
if mode == 'continue':
reply = self._continue(metadata_stream, params)
elif mode in ['pause', 'terminate']:
mode_mapping = {
'pause': self._pause,
'terminate': self._terminate
}
reply = mode_mapping.get(mode)(metadata_stream)
else:
log.error("Invalid control mode %s", mode)
raise DebuggerParamValueError("Invalid control mode.")

return reply

def _continue(self, metadata_stream, params):
"""
Send RunCMD to MindSpore.

Args:
metadata_stream (MetadataHandler): The metadata_handler
params (dict): The control params.

Returns:
dict, metadata info.
"""
if metadata_stream.state != ServerStatus.WAITING.value:
self.cache_store.put_data(metadata_stream.get())
log.error("MindSpore is not ready to run. Current state is: %s", metadata_stream.state)
raise DebuggerContinueError(
"MindSpore is not ready to run or is running currently."
)
metadata_stream.state = ServerStatus.RUNNING.value
try:
self._validate_continue_params(params)
event = self._construct_run_event(params)
self._send_watchpoints()
self.cache_store.put_command(event)
except MindInsightException as err:
log.error("Failed to send run event.")
log.exception(err)
metadata_stream.state = ServerStatus.WAITING.value
raise DebuggerContinueError("Failed to send run command.")
else:
metadata_stream.enable_recheck = False
log.debug("Send the RunCMD to command queue.")
return metadata_stream.get(['state', 'enable_recheck'])

def _validate_continue_params(self, params):
"""
Validate continue params.

Args:
params (dict): The control params.

- level (str): The control granularity, `node`, `step` or `recheck` level.
Default: `step`.
- steps (int): Specify the steps that training should run.
Used when `level` is `step`.
- name (str): Specify the name of the node. Used when `level` is `node`.
- graph_name (str): The graph name.

Raises:
DebuggerParamValueError: Params are invalid.
"""
# validate level
level = params.get('level', 'step')
if level not in [RunLevel.NODE.value, RunLevel.STEP.value, RunLevel.RECHECK.value]:
log.error("Invalid Value. `level` should be `step`, `node` or `recheck`. Got %s", level)
raise DebuggerParamValueError("level` should be `step`, `node` or `recheck`.")

# validate steps
step_num = params.get('steps', 1)
if not isinstance(step_num, int) or not (step_num == -1 or 0 < step_num <= self._MAX_STEP_NUM):
log.error("Invalid step value. Step number should be integer and in [1, 2^31 - 1] or -1.")
raise DebuggerStepNumError

# validate node name
if level == RunLevel.NODE.value:
node_name = params.get('name')
graph_name = params.get('graph_name')
self._validate_continue_node_name(node_name, graph_name)

def _construct_run_event(self, params):
"""
Construct run cmd from input control params.

Args:
params (dict): The control params.

- level (str): The control granularity, `node`, `step` or `recheck` level.
Default: `step`.
- steps (int): Specify the steps that training should run.
Used when `level` is `step`.
- name (str): Specify the name of the node. Used when `level` is `node`.
- graph_name (str): The graph name.

Returns:
EventReply, control event with run command.
"""
level = params.get('level', 'step')
# construct run command events
event = get_ack_reply()
if level == 'step':
steps = params.get('steps', 1)
run_cmd = RunCMD(run_level='step', run_steps=steps)
elif level == 'node':
name = params.get('name', '')
graph_name = params.get('graph_name')
if name:
name = self.cache_store.get_stream_handler(Streams.GRAPH).get_full_name(name, graph_name)
run_cmd = RunCMD(run_level='node', node_name=name)
else:
run_cmd = RunCMD(run_level='recheck')

event.run_cmd.CopyFrom(run_cmd)
log.debug("Construct run event. %s", event)
return event

def _validate_continue_node_name(self, node_name, graph_name):
"""Validate if the node is a leaf node."""
if not node_name:
return
graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
node_type = graph_stream.get_node_type(node_name, graph_name)
if is_scope_type(node_type):
log.error("Scope type node has no tensor history.")
raise DebuggerParamValueError("Invalid leaf node name.")

def _send_watchpoints(self):
"""Set watchpoints."""
watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
set_commands = watchpoint_stream.get_pending_commands(self.cache_store.get_stream_handler(Streams.GRAPH))
if set_commands:
for set_cmd in set_commands:
event = get_ack_reply()
event.set_cmd.CopyFrom(set_cmd)
self.cache_store.put_command(event)
watchpoint_stream.sync_set_cmd(set_commands)
log.debug("Send SetCMD to MindSpore. %s", event)

def _pause(self, metadata_stream):
"""
Pause the training.

Args:
metadata_stream (MetadataHandler): The metadata stream handler.

Returns:
dict, metadata info.
"""
if metadata_stream.state != ServerStatus.RUNNING.value:
self.cache_store.put_data(metadata_stream.get())
log.error("The MindSpore is not running.")
raise DebuggerPauseError("The MindSpore is not running.")
metadata_stream.state = 'waiting'
event = get_ack_reply()
event.run_cmd.CopyFrom(RunCMD(run_level='step', run_steps=0))
self.cache_store.put_command(event)
metadata_stream.enable_recheck = False
log.debug("Send the Pause command")
return metadata_stream.get(['state', 'enable_recheck'])

def _terminate(self, metadata_stream):
"""
Terminate the training.

Args:
metadata_stream (MetadataHandler): The metadata stream handler.

Returns:
dict, metadata info.
"""
metadata_stream.state = 'pending'
self.cache_store.clean_data()
self.cache_store.clean_command()
event = get_ack_reply()
event.exit = True
self.cache_store.put_command(event)
metadata_stream.enable_recheck = False
log.debug("Send the ExitCMD.")
return metadata_stream.get(['state', 'enable_recheck'])
mode = params.pop('mode', None)
training_controller = TrainingControlOperator(self.cache_store)
training_controller.validate_mode(mode)
return training_controller.control(mode, params)

def retrieve_node_by_bfs(self, node_name, graph_name=None, ascend=False):
"""
@@ -904,27 +724,7 @@ class DebuggerServer:
Returns:
dict, metadata info.
"""
metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
# validate backend status is able to recheck watchpoint
if not metadata_stream.enable_recheck:
log.error("Recheck is not available.")
raise DebuggerRecheckError("Recheck is not available.")
metadata_stream.state = ServerStatus.RUNNING.value
metadata_stream.enable_recheck = False
# send updated watchpoint and recheck command
try:
event = self._construct_run_event({'level': 'recheck'})
self._send_watchpoints()
self.cache_store.put_command(event)
except MindInsightException as err:
log.error("Failed to send recheck event.")
log.exception(err)
metadata_stream.state = ServerStatus.WAITING.value
metadata_stream.enable_recheck = True
raise DebuggerContinueError("Failed to send run command.")
else:
log.debug("Send the recheck to command queue.")
return metadata_stream.get(['state', 'enable_recheck'])
return TrainingControlOperator(self.cache_store).recheck()

def retrieve_tensor_graph(self, tensor_name, graph_name):
"""
@@ -937,6 +737,9 @@ class DebuggerServer:
Returns:
dict, tensor graph object.
"""
if self.cache_store.get_stream_handler(Streams.METADATA).state != ServerStatus.WAITING.value:
log.error("Failed to get tensor graph the MindSpore is not in waiting state.")
raise DebuggerTensorGraphError
log.info("Retrieve tensor graph for %s from %s", tensor_name, graph_name)
tensor_graph_ops = TensorDetailInfo(self.cache_store).get_tensor_graph(tensor_name, graph_name)
return tensor_graph_ops
@@ -952,6 +755,9 @@ class DebuggerServer:
Returns:
dict, tensor hit info.
"""
if self.cache_store.get_stream_handler(Streams.METADATA).state != ServerStatus.WAITING.value:
log.error("Failed to get tensor hits as the MindSpore is not in waiting state.")
raise DebuggerTensorHitError
log.info("Retrieve tensor hits for %s from %s", tensor_name, graph_name)
watch_points = TensorDetailInfo(self.cache_store).get_tensor_watch_points(tensor_name, graph_name)
return {'watch_points': watch_points}
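Both new guards share one shape: refuse the query unless the backend is paused at a step boundary. A minimal sketch of that check, factored into a helper for illustration (the name check_waiting is hypothetical):

from mindinsight.debugger.common.exceptions.exceptions import DebuggerTensorGraphError
from mindinsight.debugger.common.utils import ServerStatus, Streams

def check_waiting(cache_store, error_cls=DebuggerTensorGraphError):
    """Raise `error_cls` unless MindSpore is waiting at a step boundary."""
    state = cache_store.get_stream_handler(Streams.METADATA).state
    if state != ServerStatus.WAITING.value:
        raise error_cls()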


+16 -19 mindinsight/debugger/stream_cache/tensor.py

@@ -130,6 +130,16 @@ class OpTensor(BaseTensor):
"""The property of tensor stats."""
return self._stats

@stats.setter
def stats(self, stats):
"""
Update tensor stats.

Args:
stats (Statistics): Instance of Statistics.
"""
self._stats = stats

@property
def tensor_comparison(self):
"""The property of tensor_comparison."""
@@ -167,15 +177,10 @@ class OpTensor(BaseTensor):
res = {}
# the type of tensor_value is one of None, np.ndarray or str
if isinstance(tensor_value, np.ndarray):
statistics = TensorUtils.get_statistics_from_tensor(tensor_value)
if not self.stats:
self.update_tensor_stats(TensorUtils.get_statistics_from_tensor(self.value))
res['statistics'] = TensorUtils.get_statistics_dict(stats=statistics, overall_stats=self.stats)
res['value'] = tensor_value.tolist()
elif isinstance(tensor_value, str):
res['value'] = tensor_value
res['statistics'] = TensorUtils.get_overall_statistic_dict(self._stats)

res['statistics'] = self.get_tensor_statistics()
return res

def get_tensor_statistics(self):
@@ -185,9 +190,11 @@ class OpTensor(BaseTensor):
Returns:
dict, overall statistics.
"""
if not self._stats:
self._stats = TensorUtils.get_statistics_from_tensor(self.value)
statistics = TensorUtils.get_overall_statistic_dict(self._stats)
if self.empty:
return {}
if not self.stats:
self.stats = TensorUtils.get_statistics_from_tensor(self.value)
statistics = TensorUtils.get_overall_statistic_dict(self.stats)
return statistics

def update_tensor_comparisons(self, tensor_comparison):
@@ -200,16 +207,6 @@ class OpTensor(BaseTensor):
"""
self._tensor_comparison = tensor_comparison

def update_tensor_stats(self, stats):
"""
Update tensor stats.

Args:
stats (Statistics) instance of Statistics.

"""
self._stats = stats

def get_tensor_value_by_shape(self, shape=None):
"""
Get tensor value by shape.

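The stats setter above replaces update_tensor_stats, and get_tensor_statistics now guards empty tensors and computes statistics lazily. A condensed restatement of that rule, assuming an OpTensor-like object exposing empty, stats and value:

from mindinsight.utils.tensor import TensorUtils

def overall_statistics(tensor):
    """Return overall statistics, computing and caching them on first use."""
    if tensor.empty:      # empty tensors report no statistics at all
        return {}
    if not tensor.stats:  # compute once, then reuse through the new setter
        tensor.stats = TensorUtils.get_statistics_from_tensor(tensor.value)
    return TensorUtils.get_overall_statistic_dict(tensor.stats)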

+1 -2 mindinsight/debugger/stream_handler/graph_handler.py

@@ -467,7 +467,7 @@ class GraphHandler(StreamHandlerBase):
Get tensor graph according to node name.

Args:
tensor_name (str): Tensor name, format is "node_name:<node_value>".
tensor_name (str): Tensor name from UI, format is "node_name:slot".
graph_name (str): The relative graph_name of the node. Default: None.

Returns:
@@ -624,7 +624,6 @@ class GraphHandler(StreamHandlerBase):
graph_name = self.graph_names[0]
return graph_name


def _add_graph_scope_for_nodes(self, nodes, graph_name):
"""
Add graph scope for nodes.


+78 -69 mindinsight/debugger/stream_handler/tensor_handler.py

@@ -13,6 +13,8 @@
# limitations under the License.
# ============================================================================
"""Define the tensor stream handler."""
from collections import namedtuple

import numpy as np

from mindinsight.datavisual.data_transform.graph.node import NodeTypeEnum
@@ -23,6 +25,7 @@ from mindinsight.debugger.stream_cache.tensor import OpTensor, ConstTensor
from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase
from mindinsight.utils.tensor import TensorUtils, TensorComparison

TensorBasicInfo = namedtuple('tensor_basic_info', ['full_name', 'node_type', 'iter'])

class TensorHandler(StreamHandlerBase):
"""Metadata Handler."""
@@ -170,7 +173,7 @@ class TensorHandler(StreamHandlerBase):
log.error("No tensor named %s at the step %s", name, step)
raise DebuggerParamValueError("No tensor named {}".format(name))
tensor_info = tensor.get_full_info(shape)
self._update_has_prev_step_field(tensor_info, name, node_type, step)
self._update_has_prev_step_field(tensor_info, name, node_type)
return {'tensor_value': tensor_info}

def _get_tensor(self, tensor_name, node_type=None, step=None):
@@ -219,35 +222,46 @@ class TensorHandler(StreamHandlerBase):
tensor_name = tensor_info.get('full_name')
node_type = tensor_info.get('node_type')
basic_info = self._get_basic_info(tensor_name, node_type)
flag = self._update_has_prev_step_field(basic_info, tensor_name, node_type, self.cur_step)
if flag is False:
missed_tensor = tensor_info.copy()
missed_tensor['iter'] = 'prev'
missed_tensors.append(missed_tensor)
log.debug("Add previous view cmd for %s", tensor_name)
# add `has_prev_step` field to tensor basic info.
missing_tensor_infos = self._update_has_prev_step_field(basic_info, tensor_name, node_type)
if basic_info:
tensor_info.update(basic_info)
if basic_info.get('value') is None:
missed_tensors.append(tensor_info)
log.debug("Add view cmd for %s", tensor_name)
else:
missed_tensors.append(tensor_info)
log.debug("Add view cmd for %s", tensor_name)
if missing_tensor_infos:
missed_tensors.extend(missing_tensor_infos)

return missed_tensors

def _update_has_prev_step_field(self, tensor_info, tensor_name, node_type, step):
def _update_has_prev_step_field(self, tensor_info, tensor_name, node_type):
"""Update has_prev_step field in tensor info."""
flag = None
cur_tensor_value = bool(tensor_info and tensor_info.get('value') is not None)
if node_type == NodeTypeEnum.PARAMETER.value:
flag = self._get_prev_tensor_value_status(tensor_name, step)
if flag and cur_tensor_value:
tensor_info['has_prev_step'] = True
return flag
missing_tensor_infos = self.get_missing_tensor_info(tensor_name, node_type)
if not missing_tensor_infos and node_type == NodeTypeEnum.PARAMETER.value and self.cur_step > 0:
tensor_info['has_prev_step'] = True
return missing_tensor_infos

def _get_prev_tensor_value_status(self, tensor_name, step):
def get_missing_tensor_info(self, tensor_name, node_type):
"""
Get missing tensor infos.

Args:
tensor_name (str): The full name of Tensor.
node_type (str): The type of the relative node.

Returns:
list, list of missing tensor basic information.
"""
step = self.cur_step
missing_tensor_infos = []
# check whether the current step value is missing
if self._is_tensor_value_missing(tensor_name, step):
missing_tensor_infos.append(TensorBasicInfo(full_name=tensor_name, node_type=node_type, iter=''))
log.debug("Add current step view cmd for %s", tensor_name)
# check whether the previous step value is missing
if node_type == NodeTypeEnum.PARAMETER.value and self._is_tensor_value_missing(tensor_name, step - 1):
missing_tensor_infos.append(TensorBasicInfo(full_name=tensor_name, node_type=node_type, iter='prev'))
log.debug("Add previous view cmd for %s", tensor_name)
return missing_tensor_infos

def _is_tensor_value_missing(self, tensor_name, step):
"""
Check whether the tensor value of the given step is missing.

@@ -256,27 +270,25 @@ class TensorHandler(StreamHandlerBase):
step (int): The step of the tensor.

Returns:
Union[None, bool], the status of previous tensor value. If True, there is valid previous
tensor value. If False, the tensor value should be queried from client.
Union[None, bool], the status of tensor value. If False, there is valid
tensor value. If True, the tensor value should be queried from client.
If None, ignore.
"""
flag = None
# check if the tensor has previous step value.
prev_step = step - 1
if prev_step < 0:
return flag
tensor = self._get_tensor(tensor_name, step=prev_step)
return bool(tensor and not tensor.empty)

def get_tensor_value_by_name(self, tensor_name, prev=False):
"""Get tensor value by name in numpy type."""
cur_step = self._cur_step
step = cur_step - 1 if prev else cur_step
if step < 0:
log.warning("%d step has no previous value for tensor: %s", cur_step, tensor_name)
return None
tensor = self._get_tensor(tensor_name, step=step)
return bool(not tensor or tensor.empty)

def get_valid_tensor_by_name(self, tensor_name, prev=False):
"""Get tensor value by name in numpy type."""
step = self.prev_step if prev else self.cur_step
if step < 0:
log.warning("%d step has no previous value for tensor: %s", self.cur_step, tensor_name)
return None
tensor = self._get_tensor(tensor_name, step=step)
if tensor and tensor.empty:
log.warning("%s has empty value.", tensor_name)
return None
return tensor

def clean_tensors(self, cur_step):
@@ -313,35 +325,29 @@ class TensorHandler(StreamHandlerBase):
Returns:
dict, the retrieved data.
"""
curr_tensor = self.get_tensor_value_by_name(tensor_name)
prev_tensor = self.get_tensor_value_by_name(tensor_name, prev=True)
curr_tensor = self.get_valid_tensor_by_name(tensor_name)
prev_tensor = self.get_valid_tensor_by_name(tensor_name, prev=True)
if not (curr_tensor and prev_tensor):
log.error("Get current step and previous step for this tensor name %s failed.", tensor_name)
raise DebuggerParamValueError(f"Get current step and previous step for this tensor name "
f"{tensor_name} failed.")
curr_tensor_slice = curr_tensor.get_tensor_value_by_shape(shape)
prev_tensor_slice = prev_tensor.get_tensor_value_by_shape(shape)
# get tensor comparison basic info
tensor_info = curr_tensor.get_basic_info()
if isinstance(tensor_info, dict):
tensor_info.pop('has_prev_step')
tensor_info.pop('value')

tensor_info.pop('has_prev_step')
tensor_info.pop('value')
# calculate tensor comparison object
tensor_comparison = curr_tensor.tensor_comparison
if not tensor_comparison or tensor_comparison.tolerance != tolerance:
if isinstance(curr_tensor.value, np.ndarray) and isinstance(prev_tensor.value, np.ndarray):
if curr_tensor.value.shape != prev_tensor.value.shape:
raise DebuggerParamValueError("The shape of these two step tensors is not the same.")
tensor_diff = TensorUtils.calc_diff_between_two_tensor(curr_tensor.value, prev_tensor.value, tolerance)
if not tensor_comparison:
stats = TensorUtils.get_statistics_from_tensor(tensor_diff)
tensor_comparison = TensorComparison(tolerance, stats, tensor_diff)
curr_tensor.update_tensor_comparisons(tensor_comparison)
else:
tensor_comparison.update(tolerance=tolerance, value=tensor_diff)
else:
raise DebuggerParamValueError("The type of tensor value should be numpy.ndarray.")

# the type of curr_tensor_slice is one of None, np.ndarray or str
if curr_tensor.value.shape != prev_tensor.value.shape:
raise DebuggerParamValueError("The shape of these two step tensors is not the same.")
tensor_diff = TensorUtils.calc_diff_between_two_tensor(curr_tensor.value, prev_tensor.value, tolerance)
stats = TensorUtils.get_statistics_from_tensor(tensor_diff)
tensor_comparison = TensorComparison(tolerance, stats, tensor_diff)
curr_tensor.update_tensor_comparisons(tensor_comparison)
# calculate diff value
# the type of curr_tensor_slice is one of np.ndarray or str
if isinstance(curr_tensor_slice, np.ndarray) and isinstance(prev_tensor_slice, np.ndarray):
if not shape:
tensor_diff_slice = tensor_comparison.value
@@ -349,22 +355,25 @@ class TensorHandler(StreamHandlerBase):
tensor_diff_slice = tensor_comparison.value[shape]
result = np.stack([prev_tensor_slice, curr_tensor_slice, tensor_diff_slice], axis=-1)
tensor_info['diff'] = result.tolist()
stats = TensorUtils.get_statistics_from_tensor(tensor_diff_slice)
curr_tensor_stats = TensorUtils.get_statistics_from_tensor(curr_tensor.value)
curr_tensor_slice_stats = TensorUtils.get_statistics_from_tensor(curr_tensor_slice)
prev_tensor_stats = TensorUtils.get_statistics_from_tensor(prev_tensor.value)
prev_tensor_slice_stats = TensorUtils.get_statistics_from_tensor(prev_tensor_slice)
tensor_info['curr_step_statistics'] = TensorUtils.get_statistics_dict(stats=curr_tensor_slice_stats,
overall_stats=curr_tensor_stats)
tensor_info['prev_step_statistics'] = TensorUtils.get_statistics_dict(stats=prev_tensor_slice_stats,
overall_stats=prev_tensor_stats)
tensor_info['statistics'] = TensorUtils.get_statistics_dict(stats=stats,
overall_stats=tensor_comparison.stats)
elif isinstance(curr_tensor_slice, str):
tensor_info['diff'] = curr_tensor_slice
# add comparison statistics
tensor_info.update(self._get_comparison_statistics(curr_tensor, prev_tensor))
reply = {'tensor_value': tensor_info}
return reply

@staticmethod
def _get_comparison_statistics(curr_tensor, prev_tensor):
"""Get comparison statistics."""
stats_info = {}
diff_tensor_stats = curr_tensor.tensor_comparison.stats
curr_tensor_stats = TensorUtils.get_statistics_from_tensor(curr_tensor.value)
prev_tensor_stats = TensorUtils.get_statistics_from_tensor(prev_tensor.value)
stats_info['curr_step_statistics'] = TensorUtils.get_overall_statistic_dict(overall_stats=curr_tensor_stats)
stats_info['prev_step_statistics'] = TensorUtils.get_overall_statistic_dict(overall_stats=prev_tensor_stats)
stats_info['statistics'] = TensorUtils.get_overall_statistic_dict(overall_stats=diff_tensor_stats)
return stats_info

def get_tensor_statistics(self, tensor_name, node_type):
"""
Get Tensor statistics.
@@ -378,6 +387,6 @@ class TensorHandler(StreamHandlerBase):
"""
res = {}
tensor = self._get_tensor(tensor_name, node_type)
if tensor:
if tensor and not tensor.empty:
res = tensor.get_tensor_statistics()
return res
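The core of the auto-update behavior is get_missing_tensor_info: it emits one TensorBasicInfo per step whose value is absent, and only parameters look one step back (to fill has_prev_step). A condensed, self-contained restatement, where the has_value callback stands in for the handler's internal _get_tensor check and 'Parameter' is assumed to be NodeTypeEnum.PARAMETER.value:

from collections import namedtuple

TensorBasicInfo = namedtuple('tensor_basic_info', ['full_name', 'node_type', 'iter'])

def missing_tensor_infos(tensor_name, node_type, step, has_value):
    """Return view-command entries for every step whose value is absent."""
    missing = []
    if not has_value(tensor_name, step):
        missing.append(TensorBasicInfo(tensor_name, node_type, ''))
    # only parameters need the previous step, and only after step 0
    if node_type == 'Parameter' and step > 0 and not has_value(tensor_name, step - 1):
        missing.append(TensorBasicInfo(tensor_name, node_type, 'prev'))
    return missing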

+30 -4 mindinsight/debugger/stream_operator/tensor_detail_info.py

@@ -15,13 +15,14 @@
"""This module is aimed to provide with tensor detail info."""
from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError
from mindinsight.debugger.common.log import LOGGER as log
from mindinsight.debugger.common.utils import Streams
from mindinsight.debugger.common.utils import Streams, create_view_event_from_tensor_basic_info


class TensorDetailInfo:
"""Manage tensor detail information."""

def __init__(self, cache):
self._put_command = cache.put_command
self._tensor_stream = cache.get_stream_handler(Streams.TENSOR)
self._graph_stream = cache.get_stream_handler(Streams.GRAPH)
self._hit_stream = cache.get_stream_handler(Streams.WATCHPOINT_HIT)
@@ -47,7 +48,7 @@ class TensorDetailInfo:
Get the graph related to specific tensor.

Args:
tensor_name (str): The name of tensor. Format like {node_name}:{slot}.
tensor_name (str): The UI name of the tensor. Format like {node_name}:{slot}.
graph_name (str): The graph name.

Returns:
@@ -70,12 +71,16 @@ class TensorDetailInfo:
self.validate_tensor_name(tensor_name=tensor_name, graph_name=graph_name)
graph = self._graph_stream.get_tensor_graph(tensor_name, graph_name)
# add watchpoint hits info and statistics info for each tensor in tensor graph.
# record missing tensor basic info
nodes = graph.get('graph', {}).get('nodes', [])
missing_tensors = []
for node in nodes:
node['graph_name'] = graph_name
for slot_info in node.get('slots', []):
self._add_watchpoint_hit_info(slot_info, node)
self._add_statistic_info(slot_info, node)
self._add_statistic_info(slot_info, node, missing_tensors)
# query missing tensor values from client
self._ask_for_missing_tensor_value(missing_tensors, tensor_name, graph_name)
return graph

def _add_watchpoint_hit_info(self, slot_info, node):
@@ -89,17 +94,38 @@ class TensorDetailInfo:
tensor_name = ':'.join([node.get('name'), slot_info.get('slot')])
slot_info.update(self._hit_stream.get_tensor_hit_infos(tensor_name))

def _add_statistic_info(self, slot_info, node):
def _add_statistic_info(self, slot_info, node, missing_tensors):
"""
Add statistics info for the tensor and record missing tensor basic info.

Args:
slot_info (dict): Slot object.
node (dict): Node object.
missing_tensors (list[TensorBasicInfo]): List of missing tensor infos.
"""
tensor_name = ':'.join([node.get('full_name'), slot_info.get('slot')])
node_type = node.get('type')
slot_info['statistics'] = self._tensor_stream.get_tensor_statistics(tensor_name, node_type)
if not slot_info.get('statistics'):
log.debug("Get missing tensor basic infos for %s", tensor_name)
cur_missing_tensors = self._tensor_stream.get_missing_tensor_info(tensor_name, node_type)
missing_tensors.extend(cur_missing_tensors)

def _ask_for_missing_tensor_value(self, missing_tensors, tensor_name, graph_name):
"""
Send view command to client to query for missing tensor values.

Args:
missing_tensors (list[TensorBasicInfo]): List of missing tensor basic infos.
tensor_name (str): The UI name of the tensor. Format like {node_name}:{slot}.
graph_name (str): The graph name.
"""
if not missing_tensors:
return
log.debug("Ask for tensor value for: %s", missing_tensors)
view_cmd = create_view_event_from_tensor_basic_info(missing_tensors)
self._put_command({'view_cmd': view_cmd, 'tensor_name': tensor_name, 'graph_name': graph_name})
log.debug("Send view cmd for tensor-graphs.")

def get_tensor_watch_points(self, tensor_name, graph_name):
"""


+273 -0 mindinsight/debugger/stream_operator/training_control_operator.py

@@ -0,0 +1,273 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""This module is aimed to deal with controlling commands."""
import enum

from mindinsight.debugger.common.exceptions.exceptions import DebuggerContinueError, DebuggerParamValueError, \
DebuggerPauseError, DebuggerRecheckError, DebuggerStepNumError
from mindinsight.debugger.common.log import LOGGER as log
from mindinsight.debugger.common.utils import Streams, get_ack_reply, ServerStatus, RunLevel, is_scope_type
from mindinsight.debugger.proto.debug_grpc_pb2 import RunCMD
from mindinsight.utils.exceptions import MindInsightException


@enum.unique
class ControlTypeEnum(enum.Enum):
"""Control Type."""
CONTINUE = 'continue' # continue to run training
PAUSE = 'pause' # suspend training
TERMINATE = 'terminate' # terminate training


class TrainingControlOperator:
"""Control training operator."""
# max step number should be less than int32
_MAX_STEP_NUM = 2 ** 31 - 1

def __init__(self, cache_store):
self._cache_store = cache_store
self._watchpoint_stream = cache_store.get_stream_handler(Streams.WATCHPOINT)
self._graph_stream = cache_store.get_stream_handler(Streams.GRAPH)
self._metadata_stream = cache_store.get_stream_handler(Streams.METADATA)

@staticmethod
def validate_mode(mode):
"""Validate mode."""
enum_members = [item.value for item in ControlTypeEnum]
if mode not in enum_members:
log.error("Invalid control mode %s", mode)
raise DebuggerParamValueError("Invalid control mode.")

def control(self, mode, params):
"""
Control the training process.

Args:
mode (str): Acceptable control command, including `continue`,
`pause` and `terminate`.
params (dict): The control params.

- level (str): The control granularity, `node` level or `step` level.
Default: `step`.
- steps (int): Specify the steps that training should run.
Used when `level` is `step`.
- name (str): Specify the name of the node. Used when `level` is `node`.
- graph_name (str): The graph name.

Returns:
dict, the response.
"""
if mode == ControlTypeEnum.CONTINUE.value:
reply = self.continue_training(params)
else:
mode_mapping = {
ControlTypeEnum.PAUSE.value: self.pause_training,
ControlTypeEnum.TERMINATE.value: self.terminate_training
}
reply = mode_mapping.get(mode)()
return reply

def continue_training(self, params):
"""
Send RunCMD to MindSpore.

Args:
params (dict): The control params.

Returns:
dict, metadata info.
"""
metadata_stream = self._metadata_stream
if metadata_stream.state != ServerStatus.WAITING.value:
self._cache_store.put_data(metadata_stream.get())
log.error("MindSpore is not ready to run. Current state is: %s", metadata_stream.state)
raise DebuggerContinueError(
"MindSpore is not ready to run or is running currently."
)
metadata_stream.state = ServerStatus.RUNNING.value
try:
self._validate_continue_params(params)
event = self._construct_run_event(params)
self._send_watchpoints()
self._cache_store.put_command(event)
except MindInsightException as err:
log.error("Failed to send run event.")
log.exception(err)
metadata_stream.state = ServerStatus.WAITING.value
raise DebuggerContinueError("Failed to send run command.")
else:
metadata_stream.enable_recheck = False
log.debug("Send the RunCMD to command queue.")
return metadata_stream.get(['state', 'enable_recheck'])

def _validate_continue_params(self, params):
"""
Validate continue params.

Args:
params (dict): The control params.

- level (str): The control granularity, `node`, `step` or `recheck` level.
Default: `step`.
- steps (int): Specify the steps that training should run.
Used when `level` is `step`.
- name (str): Specify the name of the node. Used when `level` is `node`.
- graph_name (str): The graph name.

Raises:
DebuggerParamValueError: Params are invalid.
DebuggerStepNumError: Step number is invalid.
"""
# validate level
level = params.get('level', 'step')
if level not in [RunLevel.NODE.value, RunLevel.STEP.value, RunLevel.RECHECK.value]:
log.error("Invalid Value. `level` should be `step`, `node` or `recheck`. Got %s", level)
raise DebuggerParamValueError("level` should be `step`, `node` or `recheck`.")

# validate steps
step_num = params.get('steps', 1)
if not isinstance(step_num, int) or not (step_num == -1 or 0 < step_num <= self._MAX_STEP_NUM):
log.error("Invalid step value. Step number should be integer and in [1, 2^31 - 1] or -1.")
raise DebuggerStepNumError

# validate node name
if level == RunLevel.NODE.value:
node_name = params.get('name')
graph_name = params.get('graph_name')
self._validate_continue_node_name(node_name, graph_name)

def _validate_continue_node_name(self, node_name, graph_name):
"""Validate if the node is a leaf node."""
if not node_name:
return
node_type = self._graph_stream.get_node_type(node_name, graph_name)
if is_scope_type(node_type):
log.error("Scope type node has no tensor history.")
raise DebuggerParamValueError("Invalid leaf node name.")

def _construct_run_event(self, params):
"""
Construct run cmd from input control params.

Args:
params (dict): The control params.

- level (str): The control granularity, `node`, `step` or `recheck` level.
Default: `step`.
- steps (int): Specify the steps that training should run.
Used when `level` is `step`.
- name (str): Specify the name of the node. Used when `level` is `node`.
- graph_name (str): The graph name.

Returns:
EventReply, control event with run command.
"""
level = params.get('level', 'step')
# construct run command events
event = get_ack_reply()
if level == 'step':
steps = params.get('steps', 1)
run_cmd = RunCMD(run_level='step', run_steps=steps)
elif level == 'node':
name = params.get('name', '')
graph_name = params.get('graph_name')
if name:
name = self._cache_store.get_stream_handler(Streams.GRAPH).get_full_name(name, graph_name)
run_cmd = RunCMD(run_level='node', node_name=name)
else:
run_cmd = RunCMD(run_level='recheck')

event.run_cmd.CopyFrom(run_cmd)
log.debug("Construct run event. %s", event)
return event

def _send_watchpoints(self):
"""Send watchpoints to client."""
set_commands = self._watchpoint_stream.get_pending_commands(self._graph_stream)
if not set_commands:
return
for set_cmd in set_commands:
event = get_ack_reply()
event.set_cmd.CopyFrom(set_cmd)
self._cache_store.put_command(event)
log.debug("Send SetCMD to MindSpore. %s", event)
self._watchpoint_stream.sync_set_cmd(set_commands)

def pause_training(self):
"""
Pause the training.

Returns:
dict, metadata info.
"""
metadata_stream = self._metadata_stream
if metadata_stream.state != ServerStatus.RUNNING.value:
self._cache_store.put_data(metadata_stream.get())
log.error("The MindSpore is not running.")
raise DebuggerPauseError("The MindSpore is not running.")
metadata_stream.state = 'waiting'
event = get_ack_reply()
event.run_cmd.CopyFrom(RunCMD(run_level='step', run_steps=0))
self._cache_store.put_command(event)
metadata_stream.enable_recheck = False
log.debug("Send the Pause command")
return metadata_stream.get(['state', 'enable_recheck'])

def terminate_training(self):
"""
Terminate the training.

Returns:
dict, metadata info.
"""
metadata_stream = self._metadata_stream
metadata_stream.state = 'pending'
self._cache_store.clean_data()
self._cache_store.clean_command()
event = get_ack_reply()
event.exit = True
self._cache_store.put_command(event)
metadata_stream.enable_recheck = False
log.debug("Send the ExitCMD.")
return metadata_stream.get(['state', 'enable_recheck'])

def recheck(self):
"""
Recheck all watchpoints.

Returns:
dict, metadata info.
"""
metadata_stream = self._metadata_stream
# validate backend status is able to recheck watchpoint
if not metadata_stream.enable_recheck:
log.error("Recheck is not available.")
raise DebuggerRecheckError("Recheck is not available.")
metadata_stream.state = ServerStatus.RUNNING.value
metadata_stream.enable_recheck = False
# send updated watchpoint and recheck command
try:
event = self._construct_run_event({'level': 'recheck'})
self._send_watchpoints()
self._cache_store.put_command(event)
except MindInsightException as err:
log.error("Failed to send recheck event.")
log.exception(err)
metadata_stream.state = ServerStatus.WAITING.value
metadata_stream.enable_recheck = True
raise DebuggerContinueError("Failed to send recheck command.")
else:
log.debug("Send the recheck to command queue.")
return metadata_stream.get(['state', 'enable_recheck'])
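For reference, this is how the slimmed-down DebuggerServer.control above now drives the operator; a sketch assuming an initialized cache_store:

from mindinsight.debugger.stream_operator.training_control_operator import TrainingControlOperator

def control(cache_store, params):
    """Dispatch a control request such as {'mode': 'continue', 'steps': 2}."""
    mode = params.pop('mode', None)
    operator = TrainingControlOperator(cache_store)
    operator.validate_mode(mode)  # raises DebuggerParamValueError on a bad mode
    return operator.control(mode, params)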

+2 -0 mindinsight/utils/tensor.py

@@ -316,6 +316,8 @@ class TensorUtils:
Returns:
dict, overall statistics.
"""
if not overall_stats:
return {}
res = {
"overall_max": float(overall_stats.max),
"overall_min": float(overall_stats.min),


+0 -21 tests/st/func/debugger/expect_results/restful_results/compare_tensors.json

@@ -17,13 +17,6 @@
]
],
"curr_step_statistics": {
"max": 6.0,
"min": 1.0,
"avg": 3.5,
"count": 6,
"nan_count": 0,
"neg_inf_count": 0,
"pos_inf_count": 0,
"overall_max": 6.0,
"overall_min": 1.0,
"overall_avg": 3.5,
@@ -36,13 +29,6 @@
"overall_pos_zero_count": 6.0
},
"prev_step_statistics": {
"max": 6.0,
"min": 1.0,
"avg": 3.5,
"count": 6,
"nan_count": 0,
"neg_inf_count": 0,
"pos_inf_count": 0,
"overall_max": 6.0,
"overall_min": 1.0,
"overall_avg": 3.5,
@@ -55,13 +41,6 @@
"overall_pos_zero_count": 6.0
},
"statistics": {
"max": 0.0,
"min": 0.0,
"avg": 0.0,
"count": 6,
"nan_count": 0,
"neg_inf_count": 0,
"pos_inf_count": 0,
"overall_max": 0.0,
"overall_min": 0.0,
"overall_avg": 0.0,


+29 -1 tests/st/func/debugger/expect_results/restful_results/retrieve_tensor_value.json

@@ -1 +1,29 @@
{"tensor_value": {"full_name": "Default/TransData-op99:0", "step": 1, "dtype": "DT_FLOAT32", "shape": [2, 3], "has_prev_step": false, "statistics": {"max": 6.0, "min": 5.0, "avg": 5.5, "count": 2, "nan_count": 0, "neg_inf_count": 0, "pos_inf_count": 0, "overall_max": 6.0, "overall_min": 1.0, "overall_avg": 3.5, "overall_count": 6, "overall_nan_count": 0, "overall_neg_inf_count": 0, "overall_pos_inf_count": 0, "overall_zero_count": 0.0, "overall_neg_zero_count": 0.0, "overall_pos_zero_count": 6.0}, "value": [5.0, 6.0], "name": "Default/TransData-op99:0"}}
{
"tensor_value": {
"full_name": "Default/TransData-op99:0",
"step": 1,
"dtype": "DT_FLOAT32",
"shape": [
2,
3
],
"has_prev_step": false,
"statistics": {
"overall_max": 6.0,
"overall_min": 1.0,
"overall_avg": 3.5,
"overall_count": 6,
"overall_nan_count": 0,
"overall_neg_inf_count": 0,
"overall_pos_inf_count": 0,
"overall_zero_count": 0.0,
"overall_neg_zero_count": 0.0,
"overall_pos_zero_count": 6.0
},
"value": [
5.0,
6.0
],
"name": "Default/TransData-op99:0"
}
}

+1 -1 tests/st/func/debugger/mock_ms_client.py

@@ -217,5 +217,5 @@ class MockDebuggerClientThread:
return self._debugger_client_thread

def __exit__(self, exc_type, exc_val, exc_tb):
self._debugger_client_thread.join(timeout=3)
self._debugger_client_thread.join(timeout=2)
self._debugger_client.flag = False

+1 -0 tests/ut/debugger/stream_handler/test_metadata_handler.py

@@ -17,6 +17,7 @@ from mindinsight.debugger.common.utils import ServerStatus
from mindinsight.debugger.stream_handler.metadata_handler import MetadataHandler
from mindinsight.debugger.proto.debug_grpc_pb2 import Metadata


class TestMetadataHandler:
"""test class for MetadataHandler"""
def setup_method(self):


+1 -1 tests/ut/debugger/stream_handler/test_tensor_handler.py

@@ -40,7 +40,7 @@ class TestTensorHandler:

def test_get_tensor_value_by_name_none(self):
"""Test get_tensor_value_by_name."""
res = self.tensor_handler.get_tensor_value_by_name('tensor_name', True)
res = self.tensor_handler.get_valid_tensor_by_name('tensor_name', True)
assert res is None

@mock.patch.object(log, "error")


+15 -0 tests/ut/debugger/stream_operator/__init__.py

@@ -0,0 +1,15 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test for debugger stream operator."""

+66 -0 tests/ut/debugger/stream_operator/test_training_control_operator.py

@@ -0,0 +1,66 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Function:
Test debugger training control operator.
Usage:
pytest tests/ut/debugger/stream_operator/test_training_control_operator.py
"""
from unittest import mock

import pytest

from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError
from mindinsight.debugger.debugger_cache import DebuggerCache
from mindinsight.debugger.proto.debug_grpc_pb2 import RunCMD
from mindinsight.debugger.stream_handler import GraphHandler, MetadataHandler
from mindinsight.debugger.stream_operator.training_control_operator import TrainingControlOperator


class TestTrainingControlOperator:
"""Test debugger server."""

@classmethod
def setup_class(cls):
"""Initialize for test class."""
cls._server = None

def setup_method(self):
"""Prepare debugger server object."""
cache_store = DebuggerCache()
cache_store.initialize()
self._server = TrainingControlOperator(cache_store)

@mock.patch.object(GraphHandler, 'get_node_type')
def test_validate_leaf_name(self, *args):
"""Test validate leaf name."""
args[0].return_value = 'name_scope'
with pytest.raises(DebuggerParamValueError, match='Invalid leaf node name.'):
self._server._validate_continue_node_name(node_name='mock_node_name', graph_name='mock_graph_name')

@pytest.mark.parametrize('mode, cur_state, state', [
('continue', 'waiting', 'running'),
('pause', 'running', 'waiting'),
('terminate', 'waiting', 'pending')])
def test_control(self, mode, cur_state, state):
"""Test control request."""
with mock.patch.object(MetadataHandler, 'state', cur_state):
res = self._server.control(mode=mode, params={})
assert res == {'metadata': {'enable_recheck': False, 'state': state}}

def test_construct_run_event(self):
"""Test construct run event."""
res = self._server._construct_run_event({'level': 'node'})
assert res.run_cmd == RunCMD(run_level='node', node_name='')

+5 -5 tests/ut/debugger/test_debugger_grpc_server.py

@@ -68,7 +68,7 @@ class MockDataGenerator:
view_event = get_ack_reply()
ms_tensor = view_event.view_cmd.tensors.add()
ms_tensor.node_name, ms_tensor.slot = 'mock_node_name', '0'
event = {'view_cmd': view_event, 'node_name': 'mock_node_name'}
event = {'view_cmd': view_event, 'node_name': 'mock_node_name', 'graph_name': 'mock_graph_name'}
return event

@staticmethod
@@ -180,10 +180,10 @@ class TestDebuggerGrpcServer:
def test_deal_with_old_command_with_view_cmd(self, *args):
"""Test deal with view command."""
cmd = MockDataGenerator.get_view_cmd()
args[1].return_value = ('0', cmd)
args[1].return_value = ('0', cmd.copy())
res = self._server._deal_with_old_command()
assert res == cmd.get('view_cmd')
expect_received_view_cmd = {'node_name': cmd.get('node_name'), 'wait_for_tensor': True}
assert res == cmd.pop('view_cmd')
expect_received_view_cmd = {'node_info': cmd, 'wait_for_tensor': True}
assert getattr(self._server, '_received_view_cmd') == expect_received_view_cmd

@mock.patch.object(DebuggerCache, 'get_command')
@@ -201,7 +201,7 @@ class TestDebuggerGrpcServer:
"""Test wait for run command."""
pause_cmd = MockDataGenerator.get_run_cmd(steps=0)
empty_view_cmd = MockDataGenerator.get_view_cmd()
empty_view_cmd.pop('node_name')
empty_view_cmd.pop('view_cmd')
run_cmd = MockDataGenerator.get_run_cmd(steps=2)
args[0].side_effect = [('0', pause_cmd), ('0', empty_view_cmd), ('0', run_cmd)]
setattr(self._server, '_status', ServerStatus.WAITING)


+0 -24 tests/ut/debugger/test_debugger_server.py

@@ -32,7 +32,6 @@ from mindinsight.debugger.common.utils import Streams
from mindinsight.debugger.debugger_cache import DebuggerCache
from mindinsight.debugger.debugger_server import DebuggerServer
from mindinsight.debugger.debugger_server import grpc_server_base
from mindinsight.debugger.proto.debug_grpc_pb2 import RunCMD
from mindinsight.debugger.stream_handler import GraphHandler, WatchpointHandler, MetadataHandler, \
TensorHandler
from tests.ut.debugger.configurations import compare_debugger_result_with_file, mock_tensor_history
@@ -154,13 +153,6 @@ class TestDebuggerServer:
res = self._server.retrieve_tensor_history('mock_node_name')
compare_debugger_result_with_file(res, 'debugger_server/retrieve_tensor_history.json')

@mock.patch.object(GraphHandler, 'get_node_type')
def test_validate_leaf_name(self, *args):
"""Test validate leaf name."""
args[0].return_value = 'name_scope'
with pytest.raises(DebuggerParamValueError, match='Invalid leaf node name.'):
self._server._validate_continue_node_name(node_name='mock_node_name', graph_name='mock_graph_name')

@mock.patch.object(TensorHandler, 'get')
@mock.patch.object(DebuggerServer, '_get_tensor_name_and_type_by_ui_name')
def test_retrieve_tensor_value(self, *args):
@@ -187,7 +179,6 @@ class TestDebuggerServer:
res = self._server._retrieve_watchpoint({'watch_point_id': 1})
assert res == mock_watchpoint

@mock.patch.object(DebuggerServer, '_validate_continue_node_name')
@mock.patch.object(DebuggerServer, '_get_tensor_history')
@mock.patch.object(DebuggerServer, '_get_nodes_info', return_value={'graph': {}})
def test_retrieve_watchpoint_hit(self, *args):
@@ -238,18 +229,3 @@ class TestDebuggerServer:
args[0].return_value = None
res = self._server.delete_watchpoint(1)
assert res == {'metadata': {'enable_recheck': True, 'state': 'waiting'}}

@pytest.mark.parametrize('mode, cur_state, state', [
('continue', 'waiting', 'running'),
('pause', 'running', 'waiting'),
('terminate', 'waiting', 'pending')])
def test_control(self, mode, cur_state, state):
"""Test control request."""
with mock.patch.object(MetadataHandler, 'state', cur_state):
res = self._server.control({'mode': mode})
assert res == {'metadata': {'enable_recheck': False, 'state': state}}

def test_construct_run_event(self):
"""Test construct run event."""
res = self._server._construct_run_event({'level': 'node'})
assert res.run_cmd == RunCMD(run_level='node', node_name='')
