diff --git a/mindinsight/debugger/common/exceptions/error_code.py b/mindinsight/debugger/common/exceptions/error_code.py
index 615f18ff..2528ae39 100644
--- a/mindinsight/debugger/common/exceptions/error_code.py
+++ b/mindinsight/debugger/common/exceptions/error_code.py
@@ -27,7 +27,6 @@ class DebuggerErrors(DebuggerErrorCodes):
     """Debugger error codes."""
     PARAM_TYPE_ERROR = 0 | _PARAM_ERROR_MASK
     PARAM_VALUE_ERROR = 1 | _PARAM_ERROR_MASK
-    STEP_NUM_ERROR = 2 | _PARAM_ERROR_MASK
 
     NODE_NOT_IN_GRAPH_ERROR = 0 | _DEBUGGER_GRAPH_ERROR
 
@@ -40,6 +39,8 @@ class DebuggerErrors(DebuggerErrorCodes):
     PAUSE_ERROR = 4 | _DEBUGGER_RUNNING_ERROR
     COMPARE_TENSOR_ERROR = 5 | _DEBUGGER_RUNNING_ERROR
     RECHECK_ERROR = 6 | _DEBUGGER_RUNNING_ERROR
+    TENSOR_GRAPH_ERROR = 7 | _DEBUGGER_RUNNING_ERROR
+    TENSOR_HIT_ERROR = 8 | _DEBUGGER_RUNNING_ERROR
 
 
 @unique
@@ -56,3 +57,5 @@ class DebuggerErrorMsg(Enum):
     CONTINUE_ERROR = "Continue debugging failed. {}"
     PAUSE_ERROR = "Pause debugging failed. {}"
     RECHECK_ERROR = "Recheck failed. {}"
+    TENSOR_GRAPH_ERROR = "Get tensor graphs failed."
+    TENSOR_HIT_ERROR = "Get tensor hits failed."
diff --git a/mindinsight/debugger/common/exceptions/exceptions.py b/mindinsight/debugger/common/exceptions/exceptions.py
index 5f900e3e..0c27533e 100644
--- a/mindinsight/debugger/common/exceptions/exceptions.py
+++ b/mindinsight/debugger/common/exceptions/exceptions.py
@@ -146,3 +146,25 @@ class DebuggerStepNumError(MindInsightException):
             message="The type of step number should be int32.",
             http_code=400
         )
+
+
+class DebuggerTensorGraphError(MindInsightException):
+    """The error about getting the tensor graph."""
+
+    def __init__(self):
+        super(DebuggerTensorGraphError, self).__init__(
+            error=DebuggerErrors.TENSOR_GRAPH_ERROR,
+            message=DebuggerErrorMsg.TENSOR_GRAPH_ERROR.value,
+            http_code=400
+        )
+
+
+class DebuggerTensorHitError(MindInsightException):
+    """The error about getting tensor hits."""
+
+    def __init__(self):
+        super(DebuggerTensorHitError, self).__init__(
+            error=DebuggerErrors.TENSOR_HIT_ERROR,
+            message=DebuggerErrorMsg.TENSOR_HIT_ERROR.value,
+            http_code=400
+        )
diff --git a/mindinsight/debugger/common/utils.py b/mindinsight/debugger/common/utils.py
index 7011bfe5..9dceaf58 100644
--- a/mindinsight/debugger/common/utils.py
+++ b/mindinsight/debugger/common/utils.py
@@ -115,29 +115,29 @@ def wrap_reply_response(error_code=None, error_message=None):
     return reply
 
 
-def create_view_event_from_tensor_history(tensor_history):
+def create_view_event_from_tensor_basic_info(tensors_info):
     """
     Create view event reply according to tensor names.
 
     Args:
-        tensor_history (list[dict]): The list of tensor history. Each element has keys:
-            `name`, `node_type`.
+        tensors_info (list[TensorBasicInfo]): The list of TensorBasicInfo. Each element has keys:
+            `full_name`, `node_type`, `iter`.
 
     Returns:
         EventReply, the event reply with view cmd.
""" view_event = get_ack_reply() - for tensor_info in tensor_history: - node_type = tensor_info.get('node_type') + for tensor_info in tensors_info: + node_type = tensor_info.node_type if node_type == NodeTypeEnum.CONST.value: continue - truncate_tag = tensor_info.get('node_type') == NodeTypeEnum.PARAMETER.value - tensor_name = tensor_info.get('full_name', '') + truncate_tag = node_type == NodeTypeEnum.PARAMETER.value + tensor_name = tensor_info.full_name # create view command ms_tensor = view_event.view_cmd.tensors.add() ms_tensor.node_name, ms_tensor.slot = tensor_name.rsplit(':', 1) ms_tensor.truncate = truncate_tag - ms_tensor.iter = 'prev' if tensor_info.get('iter') else '' + ms_tensor.iter = tensor_info.iter return view_event diff --git a/mindinsight/debugger/debugger_grpc_server.py b/mindinsight/debugger/debugger_grpc_server.py index eb034972..79fa081e 100644 --- a/mindinsight/debugger/debugger_grpc_server.py +++ b/mindinsight/debugger/debugger_grpc_server.py @@ -159,15 +159,15 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): def _send_received_tensor_tag(self): """Send received_finish_tag.""" - node_name = self._received_view_cmd.get('node_name') - if not node_name or self._received_view_cmd.get('wait_for_tensor'): + node_info = self._received_view_cmd.get('node_info') + if not node_info or self._received_view_cmd.get('wait_for_tensor'): return metadata = self._cache_store.get_stream_handler(Streams.METADATA).get(['step', 'state']) - ret = {'receive_tensor': {'node_name': node_name}} + ret = {'receive_tensor': node_info.copy()} ret.update(metadata) self._cache_store.put_data(ret) self._received_view_cmd.clear() - log.debug("Send receive tensor flag for %s", node_name) + log.debug("Send receive tensor flag for %s", node_info) def _send_watchpoint_hit_flag(self): """Send Watchpoint hit flag.""" @@ -281,14 +281,26 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): return event def _deal_with_view_cmd(self, event): - """Deal with view cmd.""" - view_cmd = event.get('view_cmd') - node_name = event.get('node_name') - log.debug("Receive view cmd for node: %s.", node_name) - if not (view_cmd and node_name): + """ + Deal with view cmd. + + Args: + event (dict): View command params. + + - view_cmd (EventReply): EventReply with view command. + - node_name (str): The center node name for view command. + - tensor_name (str): The center tensor name for view command. + - graph_name (str): The graph name of center node. + + Returns: + EventReply, view command to be sent to client. + """ + view_cmd = event.pop('view_cmd', None) + log.debug("Receive view cmd for node: %s.", event) + if not (view_cmd and event): log.debug("Invalid view command. 
Ignore it.") return None - self._received_view_cmd['node_name'] = node_name + self._received_view_cmd['node_info'] = event self._received_view_cmd['wait_for_tensor'] = True return view_cmd @@ -395,6 +407,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): if tensor.finished: update_flag = tensor_stream.put({'step': step, 'tensor_protos': tensor_construct}) if self._received_view_cmd.get('wait_for_tensor') and update_flag: + # update_flag is used to avoid querying empty tensors again self._received_view_cmd['wait_for_tensor'] = False log.debug("Set wait for tensor flag to False.") tensor_construct = [] diff --git a/mindinsight/debugger/debugger_server.py b/mindinsight/debugger/debugger_server.py index dea59b23..84c35b20 100644 --- a/mindinsight/debugger/debugger_server.py +++ b/mindinsight/debugger/debugger_server.py @@ -16,35 +16,33 @@ import signal from concurrent import futures from threading import Thread + import grpc -from mindinsight.conditionmgr.conditionmgr import ConditionMgr +from mindinsight.conditionmgr.common.utils import NodeBasicInfo from mindinsight.conditionmgr.condition import ConditionContext, ConditionIdEnum +from mindinsight.conditionmgr.conditionmgr import ConditionMgr from mindinsight.conditionmgr.recommender import recommend_watchpoints from mindinsight.conf import settings from mindinsight.datavisual.data_transform.graph import NodeTypeEnum from mindinsight.datavisual.utils.tools import to_float from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \ DebuggerParamTypeError, DebuggerCreateWatchPointError, DebuggerUpdateWatchPointError, \ - DebuggerDeleteWatchPointError, DebuggerContinueError, DebuggerPauseError, \ - DebuggerCompareTensorError, DebuggerRecheckError, DebuggerStepNumError + DebuggerDeleteWatchPointError, DebuggerCompareTensorError, DebuggerTensorGraphError, \ + DebuggerTensorHitError from mindinsight.debugger.common.log import LOGGER as log -from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \ - create_view_event_from_tensor_history, Streams, is_scope_type, RunLevel -from mindinsight.conditionmgr.common.utils import NodeBasicInfo +from mindinsight.debugger.common.utils import ServerStatus, \ + create_view_event_from_tensor_basic_info, Streams from mindinsight.debugger.debugger_cache import DebuggerCache from mindinsight.debugger.debugger_grpc_server import DebuggerGrpcServer from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base -from mindinsight.debugger.proto.debug_grpc_pb2 import RunCMD from mindinsight.debugger.stream_operator.tensor_detail_info import TensorDetailInfo -from mindinsight.utils.exceptions import MindInsightException +from mindinsight.debugger.stream_operator.training_control_operator import TrainingControlOperator from mindinsight.utils.tensor import TensorUtils, MAX_DIMENSIONS_FOR_TENSOR class DebuggerServer: """The server manager of debugger.""" - # max step number should be less than int32 - _MAX_STEP_NUM = 2 ** 31 - 1 def __init__(self, grpc_port=None): self.grpc_port = grpc_port @@ -355,7 +353,7 @@ class DebuggerServer: graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) tensor_history = graph_stream.get_tensor_history(node_name, graph_name) # add tensor value for tensor history - self._add_tensor_value_for_tensor_history(tensor_history, node_name) + self._add_tensor_value_for_tensor_history(tensor_history, node_name, graph_name) # add hit label for tensor history watchpoint_hit_stream = 
self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
         watchpoint_hit_stream.update_tensor_history(tensor_history)
@@ -364,13 +362,14 @@ class DebuggerServer:
         tensor_history.update(metadata)
         return tensor_history
 
-    def _add_tensor_value_for_tensor_history(self, tensor_history, node_name):
+    def _add_tensor_value_for_tensor_history(self, tensor_history, node_name, graph_name):
         """
         Add tensor value for tensor_history and send ViewCMD if tensor value missed.
 
         Args:
             tensor_history (list[dict]): A list of tensor info, including name and type.
             node_name (str): The UI node name.
+            graph_name (str): The graph name.
 
         Returns:
             dict, the tensor info.
         """
@@ -378,8 +377,8 @@ class DebuggerServer:
         tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR)
         missed_tensors = tensor_stream.update_tensor_history(tensor_history)
         if missed_tensors:
-            view_cmd = create_view_event_from_tensor_history(missed_tensors)
-            self.cache_store.put_command({'view_cmd': view_cmd, 'node_name': node_name})
+            view_cmd = create_view_event_from_tensor_basic_info(missed_tensors)
+            self.cache_store.put_command({'view_cmd': view_cmd, 'node_name': node_name, 'graph_name': graph_name})
             log.debug("Send view cmd.")
 
     def retrieve_tensor_value(self, name, detail, shape, graph_name=None, prev=False):
@@ -679,189 +678,10 @@ class DebuggerServer:
             dict, the response.
         """
         log.info("Receive control request: %s.", params)
-        mode = params.get('mode')
-        metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
-        if mode == 'continue':
-            reply = self._continue(metadata_stream, params)
-        elif mode in ['pause', 'terminate']:
-            mode_mapping = {
-                'pause': self._pause,
-                'terminate': self._terminate
-            }
-            reply = mode_mapping.get(mode)(metadata_stream)
-        else:
-            log.error("Invalid control mode %s", mode)
-            raise DebuggerParamValueError("Invalid control mode.")
-
-        return reply
-
-    def _continue(self, metadata_stream, params):
-        """
-        Send RunCMD to MindSpore.
-
-        Args:
-            metadata_stream (MetadataHandler): The metadata handler.
-            params (dict): The control params.
-
-        Returns:
-            dict, metadata info.
-        """
-        if metadata_stream.state != ServerStatus.WAITING.value:
-            self.cache_store.put_data(metadata_stream.get())
-            log.error("MindSpore is not ready to run. Current state is: %s", metadata_stream.state)
-            raise DebuggerContinueError(
-                "MindSpore is not ready to run or is running currently."
-            )
-        metadata_stream.state = ServerStatus.RUNNING.value
-        try:
-            self._validate_continue_params(params)
-            event = self._construct_run_event(params)
-            self._send_watchpoints()
-            self.cache_store.put_command(event)
-        except MindInsightException as err:
-            log.error("Failed to send run event.")
-            log.exception(err)
-            metadata_stream.state = ServerStatus.WAITING.value
-            raise DebuggerContinueError("Failed to send run command.")
-        else:
-            metadata_stream.enable_recheck = False
-            log.debug("Send the RunCMD to command queue.")
-            return metadata_stream.get(['state', 'enable_recheck'])
-
-    def _validate_continue_params(self, params):
-        """
-        Validate continue params.
-
-        Args:
-            params (dict): The control params.
-
-            - level (str): The control granularity, `node`, `step` or `recheck` level.
-                Default: `step`.
-            - steps (int): Specify the steps that training should run.
-                Used when `level` is `step`.
-            - name (str): Specify the name of the node. Used when `level` is `node`.
-            - graph_name (str): The graph name.
-
-        Raises:
-            DebuggerParamValueError: Params are invalid.
- """ - # validate level - level = params.get('level', 'step') - if level not in [RunLevel.NODE.value, RunLevel.STEP.value, RunLevel.RECHECK.value]: - log.error("Invalid Value. `level` should be `step`, `node` or `recheck`. Got %s", level) - raise DebuggerParamValueError("level` should be `step`, `node` or `recheck`.") - - # validate steps - step_num = params.get('steps', 1) - if not isinstance(step_num, int) or not (step_num == -1 or 0 < step_num <= self._MAX_STEP_NUM): - log.error("Invalid step value. Step number should be integer and in [1, 2^31 - 1] or -1.") - raise DebuggerStepNumError - - # validate node name - if level == RunLevel.NODE.value: - node_name = params.get('name') - graph_name = params.get('graph_name') - self._validate_continue_node_name(node_name, graph_name) - - def _construct_run_event(self, params): - """ - Construct run cmd from input control params. - - Args: - params (dict): The control params. - - - level (str): The control granularity, `node`, `step` or `recheck` level. - Default: `step`. - - steps (int): Specify the steps that training should run. - Used when `level` is `step`. - - name (str): Specify the name of the node. Used when `level` is `node`. - - graph_name (str): The graph name. - - Returns: - EventReply, control event with run command. - """ - level = params.get('level', 'step') - # construct run command events - event = get_ack_reply() - if level == 'step': - steps = params.get('steps', 1) - run_cmd = RunCMD(run_level='step', run_steps=steps) - elif level == 'node': - name = params.get('name', '') - graph_name = params.get('graph_name') - if name: - name = self.cache_store.get_stream_handler(Streams.GRAPH).get_full_name(name, graph_name) - run_cmd = RunCMD(run_level='node', node_name=name) - else: - run_cmd = RunCMD(run_level='recheck') - - event.run_cmd.CopyFrom(run_cmd) - log.debug("Construct run event. %s", event) - return event - - def _validate_continue_node_name(self, node_name, graph_name): - """Validate if the node is a leaf node.""" - if not node_name: - return - graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) - node_type = graph_stream.get_node_type(node_name, graph_name) - if is_scope_type(node_type): - log.error("Scope type node has no tensor history.") - raise DebuggerParamValueError("Invalid leaf node name.") - - def _send_watchpoints(self): - """Set watchpoints.""" - watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT) - set_commands = watchpoint_stream.get_pending_commands(self.cache_store.get_stream_handler(Streams.GRAPH)) - if set_commands: - for set_cmd in set_commands: - event = get_ack_reply() - event.set_cmd.CopyFrom(set_cmd) - self.cache_store.put_command(event) - watchpoint_stream.sync_set_cmd(set_commands) - log.debug("Send SetCMD to MindSpore. %s", event) - - def _pause(self, metadata_stream): - """ - Pause the training. - - Args: - metadata_stream (MetadataHandler): The metadata stream handler. - - Returns: - dict, metadata info. 
-        """
-        if metadata_stream.state != ServerStatus.RUNNING.value:
-            self.cache_store.put_data(metadata_stream.get())
-            log.error("The MindSpore is not running.")
-            raise DebuggerPauseError("The MindSpore is not running.")
-        metadata_stream.state = 'waiting'
-        event = get_ack_reply()
-        event.run_cmd.CopyFrom(RunCMD(run_level='step', run_steps=0))
-        self.cache_store.put_command(event)
-        metadata_stream.enable_recheck = False
-        log.debug("Send the Pause command")
-        return metadata_stream.get(['state', 'enable_recheck'])
-
-    def _terminate(self, metadata_stream):
-        """
-        Terminate the training.
-
-        Args:
-            metadata_stream (MetadataHandler): The metadata stream handler.
-
-        Returns:
-            dict, metadata info.
-        """
-        metadata_stream.state = 'pending'
-        self.cache_store.clean_data()
-        self.cache_store.clean_command()
-        event = get_ack_reply()
-        event.exit = True
-        self.cache_store.put_command(event)
-        metadata_stream.enable_recheck = False
-        log.debug("Send the ExitCMD.")
-        return metadata_stream.get(['state', 'enable_recheck'])
+        mode = params.pop('mode', None)
+        training_controller = TrainingControlOperator(self.cache_store)
+        training_controller.validate_mode(mode)
+        return training_controller.control(mode, params)
 
     def retrieve_node_by_bfs(self, node_name, graph_name=None, ascend=False):
         """
@@ -904,27 +724,7 @@ class DebuggerServer:
         Returns:
             dict, metadata info.
         """
-        metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
-        # validate backend status is able to recheck watchpoint
-        if not metadata_stream.enable_recheck:
-            log.error("Recheck is not available.")
-            raise DebuggerRecheckError("Recheck is not available.")
-        metadata_stream.state = ServerStatus.RUNNING.value
-        metadata_stream.enable_recheck = False
-        # send updated watchpoint and recheck command
-        try:
-            event = self._construct_run_event({'level': 'recheck'})
-            self._send_watchpoints()
-            self.cache_store.put_command(event)
-        except MindInsightException as err:
-            log.error("Failed to send recheck event.")
-            log.exception(err)
-            metadata_stream.state = ServerStatus.WAITING.value
-            metadata_stream.enable_recheck = True
-            raise DebuggerContinueError("Failed to send run command.")
-        else:
-            log.debug("Send the recheck to command queue.")
-            return metadata_stream.get(['state', 'enable_recheck'])
+        return TrainingControlOperator(self.cache_store).recheck()
 
     def retrieve_tensor_graph(self, tensor_name, graph_name):
         """
@@ -937,6 +737,9 @@ class DebuggerServer:
         Returns:
             dict, tensor graph object.
         """
+        if self.cache_store.get_stream_handler(Streams.METADATA).state != ServerStatus.WAITING.value:
+            log.error("Failed to get tensor graph as MindSpore is not in waiting state.")
+            raise DebuggerTensorGraphError
         log.info("Retrieve tensor graph for %s from %s", tensor_name, graph_name)
         tensor_graph_ops = TensorDetailInfo(self.cache_store).get_tensor_graph(tensor_name, graph_name)
         return tensor_graph_ops
 
@@ -952,6 +755,9 @@ class DebuggerServer:
         Returns:
             dict, tensor hit info.
""" + if self.cache_store.get_stream_handler(Streams.METADATA).state != ServerStatus.WAITING.value: + log.error("Failed to get tensor hits as the MindSpore is not in waiting state.") + raise DebuggerTensorHitError log.info("Retrieve tensor hits for %s from %s", tensor_name, graph_name) watch_points = TensorDetailInfo(self.cache_store).get_tensor_watch_points(tensor_name, graph_name) return {'watch_points': watch_points} diff --git a/mindinsight/debugger/stream_cache/tensor.py b/mindinsight/debugger/stream_cache/tensor.py index 1f31ce2f..dca1a167 100644 --- a/mindinsight/debugger/stream_cache/tensor.py +++ b/mindinsight/debugger/stream_cache/tensor.py @@ -130,6 +130,16 @@ class OpTensor(BaseTensor): """The property of tensor stats.""" return self._stats + @stats.setter + def stats(self, stats): + """ + Update tensor stats. + + Args: + stats (Statistics): Instance of Statistics. + """ + self._stats = stats + @property def tensor_comparison(self): """The property of tensor_comparison.""" @@ -167,15 +177,10 @@ class OpTensor(BaseTensor): res = {} # the type of tensor_value is one of None, np.ndarray or str if isinstance(tensor_value, np.ndarray): - statistics = TensorUtils.get_statistics_from_tensor(tensor_value) - if not self.stats: - self.update_tensor_stats(TensorUtils.get_statistics_from_tensor(self.value)) - res['statistics'] = TensorUtils.get_statistics_dict(stats=statistics, overall_stats=self.stats) res['value'] = tensor_value.tolist() elif isinstance(tensor_value, str): res['value'] = tensor_value - res['statistics'] = TensorUtils.get_overall_statistic_dict(self._stats) - + res['statistics'] = self.get_tensor_statistics() return res def get_tensor_statistics(self): @@ -185,9 +190,11 @@ class OpTensor(BaseTensor): Returns: dict, overall statistics. """ - if not self._stats: - self._stats = TensorUtils.get_statistics_from_tensor(self.value) - statistics = TensorUtils.get_overall_statistic_dict(self._stats) + if self.empty: + return {} + if not self.stats: + self.stats = TensorUtils.get_statistics_from_tensor(self.value) + statistics = TensorUtils.get_overall_statistic_dict(self.stats) return statistics def update_tensor_comparisons(self, tensor_comparison): @@ -200,16 +207,6 @@ class OpTensor(BaseTensor): """ self._tensor_comparison = tensor_comparison - def update_tensor_stats(self, stats): - """ - Update tensor stats. - - Args: - stats (Statistics) instance of Statistics. - - """ - self._stats = stats - def get_tensor_value_by_shape(self, shape=None): """ Get tensor value by shape. diff --git a/mindinsight/debugger/stream_handler/graph_handler.py b/mindinsight/debugger/stream_handler/graph_handler.py index 510ad0e2..9de46472 100644 --- a/mindinsight/debugger/stream_handler/graph_handler.py +++ b/mindinsight/debugger/stream_handler/graph_handler.py @@ -467,7 +467,7 @@ class GraphHandler(StreamHandlerBase): Get tensor graph according to node name. Args: - tensor_name (str): Tensor name, format is "node_name:". + tensor_name (str): Tensor name from UI, format is "node_name:slot". graph_name (str): The relative graph_name of the node. Default: None. Returns: @@ -624,7 +624,6 @@ class GraphHandler(StreamHandlerBase): graph_name = self.graph_names[0] return graph_name - def _add_graph_scope_for_nodes(self, nodes, graph_name): """ Add graph scope for nodes. 
diff --git a/mindinsight/debugger/stream_handler/tensor_handler.py b/mindinsight/debugger/stream_handler/tensor_handler.py index 1773b92b..e6478683 100644 --- a/mindinsight/debugger/stream_handler/tensor_handler.py +++ b/mindinsight/debugger/stream_handler/tensor_handler.py @@ -13,6 +13,8 @@ # limitations under the License. # ============================================================================ """Define the tensor stream handler.""" +from collections import namedtuple + import numpy as np from mindinsight.datavisual.data_transform.graph.node import NodeTypeEnum @@ -23,6 +25,7 @@ from mindinsight.debugger.stream_cache.tensor import OpTensor, ConstTensor from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase from mindinsight.utils.tensor import TensorUtils, TensorComparison +TensorBasicInfo = namedtuple('tensor_basic_info', ['full_name', 'node_type', 'iter']) class TensorHandler(StreamHandlerBase): """Metadata Handler.""" @@ -170,7 +173,7 @@ class TensorHandler(StreamHandlerBase): log.error("No tensor named %s at the step %s", name, step) raise DebuggerParamValueError("No tensor named {}".format(name)) tensor_info = tensor.get_full_info(shape) - self._update_has_prev_step_field(tensor_info, name, node_type, step) + self._update_has_prev_step_field(tensor_info, name, node_type) return {'tensor_value': tensor_info} def _get_tensor(self, tensor_name, node_type=None, step=None): @@ -219,35 +222,46 @@ class TensorHandler(StreamHandlerBase): tensor_name = tensor_info.get('full_name') node_type = tensor_info.get('node_type') basic_info = self._get_basic_info(tensor_name, node_type) - flag = self._update_has_prev_step_field(basic_info, tensor_name, node_type, self.cur_step) - if flag is False: - missed_tensor = tensor_info.copy() - missed_tensor['iter'] = 'prev' - missed_tensors.append(missed_tensor) - log.debug("Add previous view cmd for %s", tensor_name) # add `has_prev_step` field to tensor basic info. + missing_tensor_infos = self._update_has_prev_step_field(basic_info, tensor_name, node_type) if basic_info: tensor_info.update(basic_info) - if basic_info.get('value') is None: - missed_tensors.append(tensor_info) - log.debug("Add view cmd for %s", tensor_name) - else: - missed_tensors.append(tensor_info) - log.debug("Add view cmd for %s", tensor_name) + if missing_tensor_infos: + missed_tensors.extend(missing_tensor_infos) return missed_tensors - def _update_has_prev_step_field(self, tensor_info, tensor_name, node_type, step): + def _update_has_prev_step_field(self, tensor_info, tensor_name, node_type): """Update has_prev_step field in tensor info.""" - flag = None - cur_tensor_value = bool(tensor_info and tensor_info.get('value') is not None) - if node_type == NodeTypeEnum.PARAMETER.value: - flag = self._get_prev_tensor_value_status(tensor_name, step) - if flag and cur_tensor_value: - tensor_info['has_prev_step'] = True - return flag + missing_tensor_infos = self.get_missing_tensor_info(tensor_name, node_type) + if not missing_tensor_infos and node_type == NodeTypeEnum.PARAMETER.value and self.cur_step > 0: + tensor_info['has_prev_step'] = True + return missing_tensor_infos - def _get_prev_tensor_value_status(self, tensor_name, step): + def get_missing_tensor_info(self, tensor_name, node_type): + """ + Get missing tensor infos. + + Args: + tensor_name (str): The full name of Tensor. + node_type (str): The type of the relative node. + + Returns: + list, list of missing tensor basic information. 
+        """
+        step = self.cur_step
+        missing_tensor_infos = []
+        # check whether the current step value is missing
+        if self._is_tensor_value_missing(tensor_name, step):
+            missing_tensor_infos.append(TensorBasicInfo(full_name=tensor_name, node_type=node_type, iter=''))
+            log.debug("Add current step view cmd for %s", tensor_name)
+        # check whether the previous step value is missing
+        if node_type == NodeTypeEnum.PARAMETER.value and self._is_tensor_value_missing(tensor_name, step - 1):
+            missing_tensor_infos.append(TensorBasicInfo(full_name=tensor_name, node_type=node_type, iter='prev'))
+            log.debug("Add previous view cmd for %s", tensor_name)
+        return missing_tensor_infos
+
+    def _is_tensor_value_missing(self, tensor_name, step):
         """
         Check whether the tensor value is missing at the given step.
 
         Args:
             tensor_name (str): Tensor name.
             step (int): The step of the tensor.
 
         Returns:
-            Union[None, bool], the status of previous tensor value. If True, there is valid previous
-                tensor value. If False, the tensor value should be queried from client.
+            Union[None, bool], the status of the tensor value. If False, there is a valid
+                tensor value. If True, the tensor value is missing and should be queried from client.
                 If None, ignore.
         """
-        flag = None
-        # check if the tensor has previous step value.
-        prev_step = step - 1
-        if prev_step < 0:
-            return flag
-        tensor = self._get_tensor(tensor_name, step=prev_step)
-        return bool(tensor and not tensor.empty)
-
-    def get_tensor_value_by_name(self, tensor_name, prev=False):
-        """Get tensor value by name in numpy type."""
-        cur_step = self._cur_step
-        step = cur_step - 1 if prev else cur_step
         if step < 0:
-            log.warning("%d step has no previous value for tensor: %s", cur_step, tensor_name)
             return None
         tensor = self._get_tensor(tensor_name, step=step)
+        return bool(not tensor or tensor.empty)
 
+    def get_valid_tensor_by_name(self, tensor_name, prev=False):
+        """Get the valid (non-empty) tensor by name, in numpy type."""
+        step = self.prev_step if prev else self.cur_step
+        if step < 0:
+            log.warning("%d step has no previous value for tensor: %s", self.cur_step, tensor_name)
+            return None
+        tensor = self._get_tensor(tensor_name, step=step)
+        if tensor and tensor.empty:
+            log.warning("%s has empty value.", tensor_name)
+            return None
         return tensor
 
     def clean_tensors(self, cur_step):
@@ -313,35 +325,29 @@ class TensorHandler(StreamHandlerBase):
         Returns:
             dict, the retrieved data.
""" - curr_tensor = self.get_tensor_value_by_name(tensor_name) - prev_tensor = self.get_tensor_value_by_name(tensor_name, prev=True) + curr_tensor = self.get_valid_tensor_by_name(tensor_name) + prev_tensor = self.get_valid_tensor_by_name(tensor_name, prev=True) if not (curr_tensor and prev_tensor): log.error("Get current step and previous step for this tensor name %s failed.", tensor_name) raise DebuggerParamValueError(f"Get current step and previous step for this tensor name " f"{tensor_name} failed.") curr_tensor_slice = curr_tensor.get_tensor_value_by_shape(shape) prev_tensor_slice = prev_tensor.get_tensor_value_by_shape(shape) + # get tensor comparison basic info tensor_info = curr_tensor.get_basic_info() - if isinstance(tensor_info, dict): - tensor_info.pop('has_prev_step') - tensor_info.pop('value') - + tensor_info.pop('has_prev_step') + tensor_info.pop('value') + # calculate tensor comparision object tensor_comparison = curr_tensor.tensor_comparison if not tensor_comparison or tensor_comparison.tolerance != tolerance: - if isinstance(curr_tensor.value, np.ndarray) and isinstance(prev_tensor.value, np.ndarray): - if curr_tensor.value.shape != prev_tensor.value.shape: - raise DebuggerParamValueError("The shape of these two step tensors is not the same.") - tensor_diff = TensorUtils.calc_diff_between_two_tensor(curr_tensor.value, prev_tensor.value, tolerance) - if not tensor_comparison: - stats = TensorUtils.get_statistics_from_tensor(tensor_diff) - tensor_comparison = TensorComparison(tolerance, stats, tensor_diff) - curr_tensor.update_tensor_comparisons(tensor_comparison) - else: - tensor_comparison.update(tolerance=tolerance, value=tensor_diff) - else: - raise DebuggerParamValueError("The type of tensor value should be numpy.ndarray.") - - # the type of curr_tensor_slice is one of None, np.ndarray or str + if curr_tensor.value.shape != prev_tensor.value.shape: + raise DebuggerParamValueError("The shape of these two step tensors is not the same.") + tensor_diff = TensorUtils.calc_diff_between_two_tensor(curr_tensor.value, prev_tensor.value, tolerance) + stats = TensorUtils.get_statistics_from_tensor(tensor_diff) + tensor_comparison = TensorComparison(tolerance, stats, tensor_diff) + curr_tensor.update_tensor_comparisons(tensor_comparison) + # calculate diff value + # the type of curr_tensor_slice is one of np.ndarray or str if isinstance(curr_tensor_slice, np.ndarray) and isinstance(prev_tensor_slice, np.ndarray): if not shape: tensor_diff_slice = tensor_comparison.value @@ -349,22 +355,25 @@ class TensorHandler(StreamHandlerBase): tensor_diff_slice = tensor_comparison.value[shape] result = np.stack([prev_tensor_slice, curr_tensor_slice, tensor_diff_slice], axis=-1) tensor_info['diff'] = result.tolist() - stats = TensorUtils.get_statistics_from_tensor(tensor_diff_slice) - curr_tensor_stats = TensorUtils.get_statistics_from_tensor(curr_tensor.value) - curr_tensor_slice_stats = TensorUtils.get_statistics_from_tensor(curr_tensor_slice) - prev_tensor_stats = TensorUtils.get_statistics_from_tensor(prev_tensor.value) - prev_tensor_slice_stats = TensorUtils.get_statistics_from_tensor(prev_tensor_slice) - tensor_info['curr_step_statistics'] = TensorUtils.get_statistics_dict(stats=curr_tensor_slice_stats, - overall_stats=curr_tensor_stats) - tensor_info['prev_step_statistics'] = TensorUtils.get_statistics_dict(stats=prev_tensor_slice_stats, - overall_stats=prev_tensor_stats) - tensor_info['statistics'] = TensorUtils.get_statistics_dict(stats=stats, - overall_stats=tensor_comparison.stats) 
        elif isinstance(curr_tensor_slice, str):
             tensor_info['diff'] = curr_tensor_slice
+        # add comparison statistics
+        tensor_info.update(self._get_comparison_statistics(curr_tensor, prev_tensor))
         reply = {'tensor_value': tensor_info}
         return reply
 
+    @staticmethod
+    def _get_comparison_statistics(curr_tensor, prev_tensor):
+        """Get comparison statistics."""
+        stats_info = {}
+        diff_tensor_stats = curr_tensor.tensor_comparison.stats
+        curr_tensor_stats = TensorUtils.get_statistics_from_tensor(curr_tensor.value)
+        prev_tensor_stats = TensorUtils.get_statistics_from_tensor(prev_tensor.value)
+        stats_info['curr_step_statistics'] = TensorUtils.get_overall_statistic_dict(overall_stats=curr_tensor_stats)
+        stats_info['prev_step_statistics'] = TensorUtils.get_overall_statistic_dict(overall_stats=prev_tensor_stats)
+        stats_info['statistics'] = TensorUtils.get_overall_statistic_dict(overall_stats=diff_tensor_stats)
+        return stats_info
+
     def get_tensor_statistics(self, tensor_name, node_type):
         """
         Get Tensor statistics.
 
@@ -378,6 +387,6 @@
         """
         res = {}
         tensor = self._get_tensor(tensor_name, node_type)
-        if tensor:
+        if tensor and not tensor.empty:
             res = tensor.get_tensor_statistics()
         return res
diff --git a/mindinsight/debugger/stream_operator/tensor_detail_info.py b/mindinsight/debugger/stream_operator/tensor_detail_info.py
index 47d84b52..32b48ed5 100644
--- a/mindinsight/debugger/stream_operator/tensor_detail_info.py
+++ b/mindinsight/debugger/stream_operator/tensor_detail_info.py
@@ -15,13 +15,14 @@
 """This module is aimed to provide tensor detail info."""
 from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError
 from mindinsight.debugger.common.log import LOGGER as log
-from mindinsight.debugger.common.utils import Streams
+from mindinsight.debugger.common.utils import Streams, create_view_event_from_tensor_basic_info
 
 
 class TensorDetailInfo:
     """Manage tensor detail information."""
 
     def __init__(self, cache):
+        self._put_command = cache.put_command
         self._tensor_stream = cache.get_stream_handler(Streams.TENSOR)
         self._graph_stream = cache.get_stream_handler(Streams.GRAPH)
         self._hit_stream = cache.get_stream_handler(Streams.WATCHPOINT_HIT)
@@ -47,7 +48,7 @@ class TensorDetailInfo:
         Get the graph related to specific tensor.
 
         Args:
-            tensor_name (str): The name of tensor. Format like {node_name}:{slot}.
+            tensor_name (str): The UI name of tensor. Format like {node_name}:{slot}.
             graph_name (str): The graph name.
 
         Returns:
@@ -70,12 +71,16 @@ class TensorDetailInfo:
         self.validate_tensor_name(tensor_name=tensor_name, graph_name=graph_name)
         graph = self._graph_stream.get_tensor_graph(tensor_name, graph_name)
         # add watchpoint hits info and statistics info for each tensor in tensor graph.
+        # record missing tensor basic info
         nodes = graph.get('graph', {}).get('nodes', [])
+        missing_tensors = []
         for node in nodes:
             node['graph_name'] = graph_name
             for slot_info in node.get('slots', []):
                 self._add_watchpoint_hit_info(slot_info, node)
-                self._add_statistic_info(slot_info, node)
+                self._add_statistic_info(slot_info, node, missing_tensors)
+        # query missing tensor values from client
+        self._ask_for_missing_tensor_value(missing_tensors, tensor_name, graph_name)
         return graph
 
     def _add_watchpoint_hit_info(self, slot_info, node):
@@ -89,17 +94,38 @@ class TensorDetailInfo:
         tensor_name = ':'.join([node.get('name'), slot_info.get('slot')])
         slot_info.update(self._hit_stream.get_tensor_hit_infos(tensor_name))
 
-    def _add_statistic_info(self, slot_info, node):
+    def _add_statistic_info(self, slot_info, node, missing_tensors):
         """
         Add statistics info for the tensor of the given slot.
 
         Args:
             slot_info (dict): Slot object.
             node (dict): Node object.
+            missing_tensors (list[TensorBasicInfo]): List of missing tensor infos.
         """
         tensor_name = ':'.join([node.get('full_name'), slot_info.get('slot')])
         node_type = node.get('type')
         slot_info['statistics'] = self._tensor_stream.get_tensor_statistics(tensor_name, node_type)
+        if not slot_info.get('statistics'):
+            log.debug("Get missing tensor basic infos for %s", tensor_name)
+            cur_missing_tensors = self._tensor_stream.get_missing_tensor_info(tensor_name, node_type)
+            missing_tensors.extend(cur_missing_tensors)
+
+    def _ask_for_missing_tensor_value(self, missing_tensors, tensor_name, graph_name):
+        """
+        Send view command to client to query for missing tensor values.
+
+        Args:
+            missing_tensors (list[TensorBasicInfo]): List of missing tensor basic infos.
+            tensor_name (str): The UI name of tensor. Format like {node_name}:{slot}.
+            graph_name (str): The graph name.
+        """
+        if not missing_tensors:
+            return
+        log.debug("Ask for tensor value for: %s", missing_tensors)
+        view_cmd = create_view_event_from_tensor_basic_info(missing_tensors)
+        self._put_command({'view_cmd': view_cmd, 'tensor_name': tensor_name, 'graph_name': graph_name})
+        log.debug("Send view cmd for tensor-graphs.")
 
     def get_tensor_watch_points(self, tensor_name, graph_name):
         """
diff --git a/mindinsight/debugger/stream_operator/training_control_operator.py b/mindinsight/debugger/stream_operator/training_control_operator.py
new file mode 100644
index 00000000..acadb3db
--- /dev/null
+++ b/mindinsight/debugger/stream_operator/training_control_operator.py
@@ -0,0 +1,273 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +"""This module is aimed to deal with controlling commands.""" +import enum + +from mindinsight.debugger.common.exceptions.exceptions import DebuggerContinueError, DebuggerParamValueError, \ + DebuggerPauseError, DebuggerRecheckError, DebuggerStepNumError +from mindinsight.debugger.common.log import LOGGER as log +from mindinsight.debugger.common.utils import Streams, get_ack_reply, ServerStatus, RunLevel, is_scope_type +from mindinsight.debugger.proto.debug_grpc_pb2 import RunCMD +from mindinsight.utils.exceptions import MindInsightException + + +@enum.unique +class ControlTypeEnum(enum.Enum): + """Control Type.""" + CONTINUE = 'continue' # continue to run training + PAUSE = 'pause' # suspend training + TERMINATE = 'terminate' # terminate training + + +class TrainingControlOperator: + """Control training operator.""" + # max step number should be less than int32 + _MAX_STEP_NUM = 2 ** 31 - 1 + + def __init__(self, cache_store): + self._cache_store = cache_store + self._watchpoint_stream = cache_store.get_stream_handler(Streams.WATCHPOINT) + self._graph_stream = cache_store.get_stream_handler(Streams.GRAPH) + self._metadata_stream = cache_store.get_stream_handler(Streams.METADATA) + + @staticmethod + def validate_mode(mode): + """Validate mode.""" + enum_members = [item.value for item in ControlTypeEnum] + if mode not in enum_members: + log.error("Invalid control mode %s", mode) + raise DebuggerParamValueError("Invalid control mode.") + + def control(self, mode, params): + """ + Control the training process. + + Args: + mode (str): Acceptable control command, including `continue`, + `pause` and `terminate`. + params (dict): The control params. + + - level (str): The control granularity, `node` level or `step` level. + Default: `step`. + - steps (int): Specify the steps that training should run. + Used when `level` is `step`. + - name (str): Specify the name of the node. Used when `level` is `node`. + - graph_name (str): The graph name. + + Returns: + dict, the response. + """ + if mode == ControlTypeEnum.CONTINUE.value: + reply = self.continue_training(params) + else: + mode_mapping = { + ControlTypeEnum.PAUSE.value: self.pause_training, + ControlTypeEnum.TERMINATE.value: self.terminate_training + } + reply = mode_mapping.get(mode)() + return reply + + def continue_training(self, params): + """ + Send RunCMD to MindSpore. + + Args: + params (dict): The control params. + + Returns: + dict, metadata info. + """ + metadata_stream = self._metadata_stream + if metadata_stream.state != ServerStatus.WAITING.value: + self._cache_store.put_data(metadata_stream.get()) + log.error("MindSpore is not ready to run. Current state is: %s", metadata_stream.state) + raise DebuggerContinueError( + "MindSpore is not ready to run or is running currently." + ) + metadata_stream.state = ServerStatus.RUNNING.value + try: + self._validate_continue_params(params) + event = self._construct_run_event(params) + self._send_watchpoints() + self._cache_store.put_command(event) + except MindInsightException as err: + log.error("Failed to send run event.") + log.exception(err) + metadata_stream.state = ServerStatus.WAITING.value + raise DebuggerContinueError("Failed to send run command.") + else: + metadata_stream.enable_recheck = False + log.debug("Send the RunCMD to command queue.") + return metadata_stream.get(['state', 'enable_recheck']) + + def _validate_continue_params(self, params): + """ + Validate continue params. 
+
+        Args:
+            params (dict): The control params.
+
+            - level (str): The control granularity, `node`, `step` or `recheck` level.
+                Default: `step`.
+            - steps (int): Specify the steps that training should run.
+                Used when `level` is `step`.
+            - name (str): Specify the name of the node. Used when `level` is `node`.
+            - graph_name (str): The graph name.
+
+        Raises:
+            DebuggerParamValueError: Params are invalid.
+            DebuggerStepNumError: Step number is invalid.
+        """
+        # validate level
+        level = params.get('level', 'step')
+        if level not in [RunLevel.NODE.value, RunLevel.STEP.value, RunLevel.RECHECK.value]:
+            log.error("Invalid Value. `level` should be `step`, `node` or `recheck`. Got %s", level)
+            raise DebuggerParamValueError("`level` should be `step`, `node` or `recheck`.")
+
+        # validate steps
+        step_num = params.get('steps', 1)
+        if not isinstance(step_num, int) or not (step_num == -1 or 0 < step_num <= self._MAX_STEP_NUM):
+            log.error("Invalid step value. Step number should be integer and in [1, 2^31 - 1] or -1.")
+            raise DebuggerStepNumError
+
+        # validate node name
+        if level == RunLevel.NODE.value:
+            node_name = params.get('name')
+            graph_name = params.get('graph_name')
+            self._validate_continue_node_name(node_name, graph_name)
+
+    def _validate_continue_node_name(self, node_name, graph_name):
+        """Validate if the node is a leaf node."""
+        if not node_name:
+            return
+        node_type = self._graph_stream.get_node_type(node_name, graph_name)
+        if is_scope_type(node_type):
+            log.error("Scope type node has no tensor history.")
+            raise DebuggerParamValueError("Invalid leaf node name.")
+
+    def _construct_run_event(self, params):
+        """
+        Construct run cmd from input control params.
+
+        Args:
+            params (dict): The control params.
+
+            - level (str): The control granularity, `node`, `step` or `recheck` level.
+                Default: `step`.
+            - steps (int): Specify the steps that training should run.
+                Used when `level` is `step`.
+            - name (str): Specify the name of the node. Used when `level` is `node`.
+            - graph_name (str): The graph name.
+
+        Returns:
+            EventReply, control event with run command.
+        """
+        level = params.get('level', 'step')
+        # construct run command events
+        event = get_ack_reply()
+        if level == 'step':
+            steps = params.get('steps', 1)
+            run_cmd = RunCMD(run_level='step', run_steps=steps)
+        elif level == 'node':
+            name = params.get('name', '')
+            graph_name = params.get('graph_name')
+            if name:
+                name = self._cache_store.get_stream_handler(Streams.GRAPH).get_full_name(name, graph_name)
+            run_cmd = RunCMD(run_level='node', node_name=name)
+        else:
+            run_cmd = RunCMD(run_level='recheck')
+
+        event.run_cmd.CopyFrom(run_cmd)
+        log.debug("Construct run event. %s", event)
+        return event
+
+    def _send_watchpoints(self):
+        """Send watchpoints to client."""
+        set_commands = self._watchpoint_stream.get_pending_commands(self._graph_stream)
+        if not set_commands:
+            return
+        for set_cmd in set_commands:
+            event = get_ack_reply()
+            event.set_cmd.CopyFrom(set_cmd)
+            self._cache_store.put_command(event)
+            log.debug("Send SetCMD to MindSpore. %s", event)
+        self._watchpoint_stream.sync_set_cmd(set_commands)
+
+    def pause_training(self):
+        """
+        Pause the training.
+
+        Returns:
+            dict, metadata info.
+ """ + metadata_stream = self._metadata_stream + if metadata_stream.state != ServerStatus.RUNNING.value: + self._cache_store.put_data(metadata_stream.get()) + log.error("The MindSpore is not running.") + raise DebuggerPauseError("The MindSpore is not running.") + metadata_stream.state = 'waiting' + event = get_ack_reply() + event.run_cmd.CopyFrom(RunCMD(run_level='step', run_steps=0)) + self._cache_store.put_command(event) + metadata_stream.enable_recheck = False + log.debug("Send the Pause command") + return metadata_stream.get(['state', 'enable_recheck']) + + def terminate_training(self): + """ + Terminate the training. + + Returns: + dict, metadata info. + """ + metadata_stream = self._metadata_stream + metadata_stream.state = 'pending' + self._cache_store.clean_data() + self._cache_store.clean_command() + event = get_ack_reply() + event.exit = True + self._cache_store.put_command(event) + metadata_stream.enable_recheck = False + log.debug("Send the ExitCMD.") + return metadata_stream.get(['state', 'enable_recheck']) + + def recheck(self): + """ + Recheck all watchpoints. + + Returns: + dict, metadata info. + """ + metadata_stream = self._metadata_stream + # validate backend status is able to recheck watchpoint + if not metadata_stream.enable_recheck: + log.error("Recheck is not available.") + raise DebuggerRecheckError("Recheck is not available.") + metadata_stream.state = ServerStatus.RUNNING.value + metadata_stream.enable_recheck = False + # send updated watchpoint and recheck command + try: + event = self._construct_run_event({'level': 'recheck'}) + self._send_watchpoints() + self._cache_store.put_command(event) + except MindInsightException as err: + log.error("Failed to send recheck event.") + log.exception(err) + metadata_stream.state = ServerStatus.WAITING.value + metadata_stream.enable_recheck = True + raise DebuggerContinueError("Failed to send recheck command.") + else: + log.debug("Send the recheck to command queue.") + return metadata_stream.get(['state', 'enable_recheck']) diff --git a/mindinsight/utils/tensor.py b/mindinsight/utils/tensor.py index 42f6e3a8..a04adc07 100644 --- a/mindinsight/utils/tensor.py +++ b/mindinsight/utils/tensor.py @@ -316,6 +316,8 @@ class TensorUtils: Returns: dict, overall statistics. 
""" + if not overall_stats: + return {} res = { "overall_max": float(overall_stats.max), "overall_min": float(overall_stats.min), diff --git a/tests/st/func/debugger/expect_results/restful_results/compare_tensors.json b/tests/st/func/debugger/expect_results/restful_results/compare_tensors.json index 1574eb13..2dd9e11c 100644 --- a/tests/st/func/debugger/expect_results/restful_results/compare_tensors.json +++ b/tests/st/func/debugger/expect_results/restful_results/compare_tensors.json @@ -17,13 +17,6 @@ ] ], "curr_step_statistics": { - "max": 6.0, - "min": 1.0, - "avg": 3.5, - "count": 6, - "nan_count": 0, - "neg_inf_count": 0, - "pos_inf_count": 0, "overall_max": 6.0, "overall_min": 1.0, "overall_avg": 3.5, @@ -36,13 +29,6 @@ "overall_pos_zero_count": 6.0 }, "prev_step_statistics": { - "max": 6.0, - "min": 1.0, - "avg": 3.5, - "count": 6, - "nan_count": 0, - "neg_inf_count": 0, - "pos_inf_count": 0, "overall_max": 6.0, "overall_min": 1.0, "overall_avg": 3.5, @@ -55,13 +41,6 @@ "overall_pos_zero_count": 6.0 }, "statistics": { - "max": 0.0, - "min": 0.0, - "avg": 0.0, - "count": 6, - "nan_count": 0, - "neg_inf_count": 0, - "pos_inf_count": 0, "overall_max": 0.0, "overall_min": 0.0, "overall_avg": 0.0, diff --git a/tests/st/func/debugger/expect_results/restful_results/retrieve_tensor_value.json b/tests/st/func/debugger/expect_results/restful_results/retrieve_tensor_value.json index a7a534b4..2f67ce50 100644 --- a/tests/st/func/debugger/expect_results/restful_results/retrieve_tensor_value.json +++ b/tests/st/func/debugger/expect_results/restful_results/retrieve_tensor_value.json @@ -1 +1,29 @@ -{"tensor_value": {"full_name": "Default/TransData-op99:0", "step": 1, "dtype": "DT_FLOAT32", "shape": [2, 3], "has_prev_step": false, "statistics": {"max": 6.0, "min": 5.0, "avg": 5.5, "count": 2, "nan_count": 0, "neg_inf_count": 0, "pos_inf_count": 0, "overall_max": 6.0, "overall_min": 1.0, "overall_avg": 3.5, "overall_count": 6, "overall_nan_count": 0, "overall_neg_inf_count": 0, "overall_pos_inf_count": 0, "overall_zero_count": 0.0, "overall_neg_zero_count": 0.0, "overall_pos_zero_count": 6.0}, "value": [5.0, 6.0], "name": "Default/TransData-op99:0"}} \ No newline at end of file +{ + "tensor_value": { + "full_name": "Default/TransData-op99:0", + "step": 1, + "dtype": "DT_FLOAT32", + "shape": [ + 2, + 3 + ], + "has_prev_step": false, + "statistics": { + "overall_max": 6.0, + "overall_min": 1.0, + "overall_avg": 3.5, + "overall_count": 6, + "overall_nan_count": 0, + "overall_neg_inf_count": 0, + "overall_pos_inf_count": 0, + "overall_zero_count": 0.0, + "overall_neg_zero_count": 0.0, + "overall_pos_zero_count": 6.0 + }, + "value": [ + 5.0, + 6.0 + ], + "name": "Default/TransData-op99:0" + } +} \ No newline at end of file diff --git a/tests/st/func/debugger/mock_ms_client.py b/tests/st/func/debugger/mock_ms_client.py index 9baa1aa6..dd1d623d 100644 --- a/tests/st/func/debugger/mock_ms_client.py +++ b/tests/st/func/debugger/mock_ms_client.py @@ -217,5 +217,5 @@ class MockDebuggerClientThread: return self._debugger_client_thread def __exit__(self, exc_type, exc_val, exc_tb): - self._debugger_client_thread.join(timeout=3) + self._debugger_client_thread.join(timeout=2) self._debugger_client.flag = False diff --git a/tests/ut/debugger/stream_handler/test_metadata_handler.py b/tests/ut/debugger/stream_handler/test_metadata_handler.py index dda5c5dc..5f041a27 100644 --- a/tests/ut/debugger/stream_handler/test_metadata_handler.py +++ b/tests/ut/debugger/stream_handler/test_metadata_handler.py @@ -17,6 +17,7 @@ 
from mindinsight.debugger.common.utils import ServerStatus from mindinsight.debugger.stream_handler.metadata_handler import MetadataHandler from mindinsight.debugger.proto.debug_grpc_pb2 import Metadata + class TestMetadataHandler: """test class for MetadataHandler""" def setup_method(self): diff --git a/tests/ut/debugger/stream_handler/test_tensor_handler.py b/tests/ut/debugger/stream_handler/test_tensor_handler.py index 33d321bd..50ce2d9f 100644 --- a/tests/ut/debugger/stream_handler/test_tensor_handler.py +++ b/tests/ut/debugger/stream_handler/test_tensor_handler.py @@ -40,7 +40,7 @@ class TestTensorHandler: def test_get_tensor_value_by_name_none(self): """Test get_tensor_value_by_name.""" - res = self.tensor_handler.get_tensor_value_by_name('tensor_name', True) + res = self.tensor_handler.get_valid_tensor_by_name('tensor_name', True) assert res is None @mock.patch.object(log, "error") diff --git a/tests/ut/debugger/stream_operator/__init__.py b/tests/ut/debugger/stream_operator/__init__.py new file mode 100644 index 00000000..b9a89c12 --- /dev/null +++ b/tests/ut/debugger/stream_operator/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test for debugger stream operator.""" diff --git a/tests/ut/debugger/stream_operator/test_training_control_operator.py b/tests/ut/debugger/stream_operator/test_training_control_operator.py new file mode 100644 index 00000000..43de2254 --- /dev/null +++ b/tests/ut/debugger/stream_operator/test_training_control_operator.py @@ -0,0 +1,66 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Function: + Test debugger training control operator. 
Usage:
+    pytest tests/ut/debugger/stream_operator/test_training_control_operator.py
+"""
+from unittest import mock
+
+import pytest
+
+from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError
+from mindinsight.debugger.debugger_cache import DebuggerCache
+from mindinsight.debugger.proto.debug_grpc_pb2 import RunCMD
+from mindinsight.debugger.stream_handler import GraphHandler, MetadataHandler
+from mindinsight.debugger.stream_operator.training_control_operator import TrainingControlOperator
+
+
+class TestTrainingControlOperator:
+    """Test the training control operator."""
+
+    @classmethod
+    def setup_class(cls):
+        """Initialize for test class."""
+        cls._server = None
+
+    def setup_method(self):
+        """Prepare the training control operator object."""
+        cache_store = DebuggerCache()
+        cache_store.initialize()
+        self._server = TrainingControlOperator(cache_store)
+
+    @mock.patch.object(GraphHandler, 'get_node_type')
+    def test_validate_leaf_name(self, *args):
+        """Test validate leaf name."""
+        args[0].return_value = 'name_scope'
+        with pytest.raises(DebuggerParamValueError, match='Invalid leaf node name.'):
+            self._server._validate_continue_node_name(node_name='mock_node_name', graph_name='mock_graph_name')
+
+    @pytest.mark.parametrize('mode, cur_state, state', [
+        ('continue', 'waiting', 'running'),
+        ('pause', 'running', 'waiting'),
+        ('terminate', 'waiting', 'pending')])
+    def test_control(self, mode, cur_state, state):
+        """Test control request."""
+        with mock.patch.object(MetadataHandler, 'state', cur_state):
+            res = self._server.control(mode=mode, params={})
+            assert res == {'metadata': {'enable_recheck': False, 'state': state}}
+
+    def test_construct_run_event(self):
+        """Test construct run event."""
+        res = self._server._construct_run_event({'level': 'node'})
+        assert res.run_cmd == RunCMD(run_level='node', node_name='')
diff --git a/tests/ut/debugger/test_debugger_grpc_server.py b/tests/ut/debugger/test_debugger_grpc_server.py
index 2ec883f5..eca209d6 100644
--- a/tests/ut/debugger/test_debugger_grpc_server.py
+++ b/tests/ut/debugger/test_debugger_grpc_server.py
@@ -68,7 +68,7 @@ class MockDataGenerator:
         view_event = get_ack_reply()
         ms_tensor = view_event.view_cmd.tensors.add()
         ms_tensor.node_name, ms_tensor.slot = 'mock_node_name', '0'
-        event = {'view_cmd': view_event, 'node_name': 'mock_node_name'}
+        event = {'view_cmd': view_event, 'node_name': 'mock_node_name', 'graph_name': 'mock_graph_name'}
         return event
 
     @staticmethod
@@ -180,10 +180,10 @@ class TestDebuggerGrpcServer:
     def test_deal_with_old_command_with_view_cmd(self, *args):
         """Test deal with view command."""
         cmd = MockDataGenerator.get_view_cmd()
-        args[1].return_value = ('0', cmd)
+        args[1].return_value = ('0', cmd.copy())
         res = self._server._deal_with_old_command()
-        assert res == cmd.get('view_cmd')
-        expect_received_view_cmd = {'node_name': cmd.get('node_name'), 'wait_for_tensor': True}
+        assert res == cmd.pop('view_cmd')
+        expect_received_view_cmd = {'node_info': cmd, 'wait_for_tensor': True}
         assert getattr(self._server, '_received_view_cmd') == expect_received_view_cmd
 
     @mock.patch.object(DebuggerCache, 'get_command')
@@ -201,7 +201,7 @@ class TestDebuggerGrpcServer:
         """Test wait for run command."""
         pause_cmd = MockDataGenerator.get_run_cmd(steps=0)
         empty_view_cmd = MockDataGenerator.get_view_cmd()
-        empty_view_cmd.pop('node_name')
+        empty_view_cmd.pop('view_cmd')
         run_cmd = MockDataGenerator.get_run_cmd(steps=2)
         args[0].side_effect = [('0', pause_cmd), ('0', empty_view_cmd), ('0', run_cmd)]
         setattr(self._server,
'_status', ServerStatus.WAITING) diff --git a/tests/ut/debugger/test_debugger_server.py b/tests/ut/debugger/test_debugger_server.py index 1462a964..841543a1 100644 --- a/tests/ut/debugger/test_debugger_server.py +++ b/tests/ut/debugger/test_debugger_server.py @@ -32,7 +32,6 @@ from mindinsight.debugger.common.utils import Streams from mindinsight.debugger.debugger_cache import DebuggerCache from mindinsight.debugger.debugger_server import DebuggerServer from mindinsight.debugger.debugger_server import grpc_server_base -from mindinsight.debugger.proto.debug_grpc_pb2 import RunCMD from mindinsight.debugger.stream_handler import GraphHandler, WatchpointHandler, MetadataHandler, \ TensorHandler from tests.ut.debugger.configurations import compare_debugger_result_with_file, mock_tensor_history @@ -154,13 +153,6 @@ class TestDebuggerServer: res = self._server.retrieve_tensor_history('mock_node_name') compare_debugger_result_with_file(res, 'debugger_server/retrieve_tensor_history.json') - @mock.patch.object(GraphHandler, 'get_node_type') - def test_validate_leaf_name(self, *args): - """Test validate leaf name.""" - args[0].return_value = 'name_scope' - with pytest.raises(DebuggerParamValueError, match='Invalid leaf node name.'): - self._server._validate_continue_node_name(node_name='mock_node_name', graph_name='mock_graph_name') - @mock.patch.object(TensorHandler, 'get') @mock.patch.object(DebuggerServer, '_get_tensor_name_and_type_by_ui_name') def test_retrieve_tensor_value(self, *args): @@ -187,7 +179,6 @@ class TestDebuggerServer: res = self._server._retrieve_watchpoint({'watch_point_id': 1}) assert res == mock_watchpoint - @mock.patch.object(DebuggerServer, '_validate_continue_node_name') @mock.patch.object(DebuggerServer, '_get_tensor_history') @mock.patch.object(DebuggerServer, '_get_nodes_info', return_value={'graph': {}}) def test_retrieve_watchpoint_hit(self, *args): @@ -238,18 +229,3 @@ class TestDebuggerServer: args[0].return_value = None res = self._server.delete_watchpoint(1) assert res == {'metadata': {'enable_recheck': True, 'state': 'waiting'}} - - @pytest.mark.parametrize('mode, cur_state, state', [ - ('continue', 'waiting', 'running'), - ('pause', 'running', 'waiting'), - ('terminate', 'waiting', 'pending')]) - def test_control(self, mode, cur_state, state): - """Test control request.""" - with mock.patch.object(MetadataHandler, 'state', cur_state): - res = self._server.control({'mode': mode}) - assert res == {'metadata': {'enable_recheck': False, 'state': state}} - - def test_construct_run_event(self): - """Test construct run event.""" - res = self._server._construct_run_event({'level': 'node'}) - assert res.run_cmd == RunCMD(run_level='node', node_name='')
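Note: the test changes above follow the refactor that moves all run control out of DebuggerServer into TrainingControlOperator, which the new unit tests exercise directly. Below is a rough, self-contained sketch of the validate-then-dispatch shape that operator uses; ToyControlOperator and its return values are illustrative only — the real operator validates against ControlTypeEnum, raises DebuggerParamValueError, and returns metadata from its streams.

import enum


@enum.unique
class ControlType(enum.Enum):
    """Toy mirror of ControlTypeEnum from the new operator module."""
    CONTINUE = 'continue'   # continue to run training
    PAUSE = 'pause'         # suspend training
    TERMINATE = 'terminate' # terminate training


class ToyControlOperator:
    """Toy dispatcher following the validate-then-route shape of TrainingControlOperator."""

    @staticmethod
    def validate_mode(mode):
        # reject anything outside the enum before touching any state
        # (the real operator raises DebuggerParamValueError here)
        if mode not in [item.value for item in ControlType]:
            raise ValueError("Invalid control mode.")

    def control(self, mode, params):
        # `continue` takes params; the other modes take none, as in the diff
        if mode == ControlType.CONTINUE.value:
            return self._continue(params)
        mode_mapping = {
            ControlType.PAUSE.value: self._pause,
            ControlType.TERMINATE.value: self._terminate,
        }
        return mode_mapping[mode]()

    def _continue(self, params):
        return {'state': 'running', 'steps': params.get('steps', 1)}

    def _pause(self):
        return {'state': 'waiting'}

    def _terminate(self):
        return {'state': 'pending'}


if __name__ == '__main__':
    operator = ToyControlOperator()
    operator.validate_mode('continue')
    print(operator.control('continue', {'steps': 2}))  # -> {'state': 'running', 'steps': 2}

Popping `mode` out of the request params before delegating (as DebuggerServer.control now does) keeps the operator's `control(mode, params)` signature free of transport details.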