| @@ -27,7 +27,6 @@ class DebuggerErrors(DebuggerErrorCodes): | |||
| """Debugger error codes.""" | |||
| PARAM_TYPE_ERROR = 0 | _PARAM_ERROR_MASK | |||
| PARAM_VALUE_ERROR = 1 | _PARAM_ERROR_MASK | |||
| STEP_NUM_ERROR = 2 | _PARAM_ERROR_MASK | |||
| NODE_NOT_IN_GRAPH_ERROR = 0 | _DEBUGGER_GRAPH_ERROR | |||
| @@ -40,6 +39,8 @@ class DebuggerErrors(DebuggerErrorCodes): | |||
| PAUSE_ERROR = 4 | _DEBUGGER_RUNNING_ERROR | |||
| COMPARE_TENSOR_ERROR = 5 | _DEBUGGER_RUNNING_ERROR | |||
| RECHECK_ERROR = 6 | _DEBUGGER_RUNNING_ERROR | |||
| TENSOR_GRAPH_ERROR = 7 | _DEBUGGER_RUNNING_ERROR | |||
| TENSOR_HIT_ERROR = 8 | _DEBUGGER_RUNNING_ERROR | |||
| @unique | |||
| @@ -56,3 +57,5 @@ class DebuggerErrorMsg(Enum): | |||
| CONTINUE_ERROR = "Continue debugging failed. {}" | |||
| PAUSE_ERROR = "Pause debugging failed. {}" | |||
| RECHECK_ERROR = "Recheck failed. {}" | |||
| TENSOR_GRAPH_ERROR = "Get tensor graphs failed." | |||
| TENSOR_HIT_ERROR = "Get tensor hits failed." | |||
| @@ -146,3 +146,25 @@ class DebuggerStepNumError(MindInsightException): | |||
| message="The type of step number should be int32.", | |||
| http_code=400 | |||
| ) | |||
| class DebuggerTensorGraphError(MindInsightException): | |||
| """The error about getting tensor graphs.""" | |||
| def __init__(self): | |||
| super(DebuggerTensorGraphError, self).__init__( | |||
| error=DebuggerErrors.TENSOR_GRAPH_ERROR, | |||
| message=DebuggerErrorMsg.TENSOR_GRAPH_ERROR.value, | |||
| http_code=400 | |||
| ) | |||
| class DebuggerTensorHitError(MindInsightException): | |||
| """The error about getting tensor hits.""" | |||
| def __init__(self): | |||
| super(DebuggerTensorHitError, self).__init__( | |||
| error=DebuggerErrors.TENSOR_HIT_ERROR, | |||
| message=DebuggerErrorMsg.TENSOR_HIT_ERROR.value, | |||
| http_code=400 | |||
| ) | |||
| @@ -115,29 +115,29 @@ def wrap_reply_response(error_code=None, error_message=None): | |||
| return reply | |||
| def create_view_event_from_tensor_history(tensor_history): | |||
| def create_view_event_from_tensor_basic_info(tensors_info): | |||
| """ | |||
| Create view event reply according to tensor names. | |||
| Args: | |||
| tensor_history (list[dict]): The list of tensor history. Each element has keys: | |||
| `name`, `node_type`. | |||
| tensors_info (list[TensorBasicInfo]): The list of TensorBasicInfo. Each element has fields: | |||
| `full_name`, `node_type`, `iter`. | |||
| Returns: | |||
| EventReply, the event reply with view cmd. | |||
| """ | |||
| view_event = get_ack_reply() | |||
| for tensor_info in tensor_history: | |||
| node_type = tensor_info.get('node_type') | |||
| for tensor_info in tensors_info: | |||
| node_type = tensor_info.node_type | |||
| if node_type == NodeTypeEnum.CONST.value: | |||
| continue | |||
| truncate_tag = tensor_info.get('node_type') == NodeTypeEnum.PARAMETER.value | |||
| tensor_name = tensor_info.get('full_name', '') | |||
| truncate_tag = node_type == NodeTypeEnum.PARAMETER.value | |||
| tensor_name = tensor_info.full_name | |||
| # create view command | |||
| ms_tensor = view_event.view_cmd.tensors.add() | |||
| ms_tensor.node_name, ms_tensor.slot = tensor_name.rsplit(':', 1) | |||
| ms_tensor.truncate = truncate_tag | |||
| ms_tensor.iter = 'prev' if tensor_info.get('iter') else '' | |||
| ms_tensor.iter = tensor_info.iter | |||
| return view_event | |||
| @@ -159,15 +159,15 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| def _send_received_tensor_tag(self): | |||
| """Send received_finish_tag.""" | |||
| node_name = self._received_view_cmd.get('node_name') | |||
| if not node_name or self._received_view_cmd.get('wait_for_tensor'): | |||
| node_info = self._received_view_cmd.get('node_info') | |||
| if not node_info or self._received_view_cmd.get('wait_for_tensor'): | |||
| return | |||
| metadata = self._cache_store.get_stream_handler(Streams.METADATA).get(['step', 'state']) | |||
| ret = {'receive_tensor': {'node_name': node_name}} | |||
| ret = {'receive_tensor': node_info.copy()} | |||
| ret.update(metadata) | |||
| self._cache_store.put_data(ret) | |||
| self._received_view_cmd.clear() | |||
| log.debug("Send receive tensor flag for %s", node_name) | |||
| log.debug("Send receive tensor flag for %s", node_info) | |||
| def _send_watchpoint_hit_flag(self): | |||
| """Send Watchpoint hit flag.""" | |||
| @@ -281,14 +281,26 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| return event | |||
| def _deal_with_view_cmd(self, event): | |||
| """Deal with view cmd.""" | |||
| view_cmd = event.get('view_cmd') | |||
| node_name = event.get('node_name') | |||
| log.debug("Receive view cmd for node: %s.", node_name) | |||
| if not (view_cmd and node_name): | |||
| """ | |||
| Deal with view cmd. | |||
| Args: | |||
| event (dict): View command params. | |||
| - view_cmd (EventReply): EventReply with view command. | |||
| - node_name (str): The center node name for view command. | |||
| - tensor_name (str): The center tensor name for view command. | |||
| - graph_name (str): The graph name of center node. | |||
| Returns: | |||
| EventReply, view command to be sent to client. | |||
| """ | |||
| view_cmd = event.pop('view_cmd', None) | |||
| log.debug("Receive view cmd for node: %s.", event) | |||
| if not (view_cmd and event): | |||
| log.debug("Invalid view command. Ignore it.") | |||
| return None | |||
| self._received_view_cmd['node_name'] = node_name | |||
| self._received_view_cmd['node_info'] = event | |||
| self._received_view_cmd['wait_for_tensor'] = True | |||
| return view_cmd | |||
| @@ -395,6 +407,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| if tensor.finished: | |||
| update_flag = tensor_stream.put({'step': step, 'tensor_protos': tensor_construct}) | |||
| if self._received_view_cmd.get('wait_for_tensor') and update_flag: | |||
| # update_flag is used to avoid querying empty tensors again | |||
| self._received_view_cmd['wait_for_tensor'] = False | |||
| log.debug("Set wait for tensor flag to False.") | |||
| tensor_construct = [] | |||
| @@ -16,35 +16,33 @@ | |||
| import signal | |||
| from concurrent import futures | |||
| from threading import Thread | |||
| import grpc | |||
| from mindinsight.conditionmgr.conditionmgr import ConditionMgr | |||
| from mindinsight.conditionmgr.common.utils import NodeBasicInfo | |||
| from mindinsight.conditionmgr.condition import ConditionContext, ConditionIdEnum | |||
| from mindinsight.conditionmgr.conditionmgr import ConditionMgr | |||
| from mindinsight.conditionmgr.recommender import recommend_watchpoints | |||
| from mindinsight.conf import settings | |||
| from mindinsight.datavisual.data_transform.graph import NodeTypeEnum | |||
| from mindinsight.datavisual.utils.tools import to_float | |||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \ | |||
| DebuggerParamTypeError, DebuggerCreateWatchPointError, DebuggerUpdateWatchPointError, \ | |||
| DebuggerDeleteWatchPointError, DebuggerContinueError, DebuggerPauseError, \ | |||
| DebuggerCompareTensorError, DebuggerRecheckError, DebuggerStepNumError | |||
| DebuggerDeleteWatchPointError, DebuggerCompareTensorError, DebuggerTensorGraphError, \ | |||
| DebuggerTensorHitError | |||
| from mindinsight.debugger.common.log import LOGGER as log | |||
| from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \ | |||
| create_view_event_from_tensor_history, Streams, is_scope_type, RunLevel | |||
| from mindinsight.conditionmgr.common.utils import NodeBasicInfo | |||
| from mindinsight.debugger.common.utils import ServerStatus, \ | |||
| create_view_event_from_tensor_basic_info, Streams | |||
| from mindinsight.debugger.debugger_cache import DebuggerCache | |||
| from mindinsight.debugger.debugger_grpc_server import DebuggerGrpcServer | |||
| from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base | |||
| from mindinsight.debugger.proto.debug_grpc_pb2 import RunCMD | |||
| from mindinsight.debugger.stream_operator.tensor_detail_info import TensorDetailInfo | |||
| from mindinsight.utils.exceptions import MindInsightException | |||
| from mindinsight.debugger.stream_operator.training_control_operator import TrainingControlOperator | |||
| from mindinsight.utils.tensor import TensorUtils, MAX_DIMENSIONS_FOR_TENSOR | |||
| class DebuggerServer: | |||
| """The server manager of debugger.""" | |||
| # max step number should be less than int32 | |||
| _MAX_STEP_NUM = 2 ** 31 - 1 | |||
| def __init__(self, grpc_port=None): | |||
| self.grpc_port = grpc_port | |||
| @@ -355,7 +353,7 @@ class DebuggerServer: | |||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) | |||
| tensor_history = graph_stream.get_tensor_history(node_name, graph_name) | |||
| # add tensor value for tensor history | |||
| self._add_tensor_value_for_tensor_history(tensor_history, node_name) | |||
| self._add_tensor_value_for_tensor_history(tensor_history, node_name, graph_name) | |||
| # add hit label for tensor history | |||
| watchpoint_hit_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT) | |||
| watchpoint_hit_stream.update_tensor_history(tensor_history) | |||
| @@ -364,13 +362,14 @@ class DebuggerServer: | |||
| tensor_history.update(metadata) | |||
| return tensor_history | |||
| def _add_tensor_value_for_tensor_history(self, tensor_history, node_name): | |||
| def _add_tensor_value_for_tensor_history(self, tensor_history, node_name, graph_name): | |||
| """ | |||
| Add tensor values to the tensor history and send ViewCMD if any tensor value is missing. | |||
| Args: | |||
| tensor_history (list[dict]): A list of tensor info, including name and type. | |||
| node_name (str): The UI node name. | |||
| graph_name (str): The graph name. | |||
| Returns: | |||
| dict, the tensor info. | |||
| @@ -378,8 +377,8 @@ class DebuggerServer: | |||
| tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR) | |||
| missed_tensors = tensor_stream.update_tensor_history(tensor_history) | |||
| if missed_tensors: | |||
| view_cmd = create_view_event_from_tensor_history(missed_tensors) | |||
| self.cache_store.put_command({'view_cmd': view_cmd, 'node_name': node_name}) | |||
| view_cmd = create_view_event_from_tensor_basic_info(missed_tensors) | |||
| self.cache_store.put_command({'view_cmd': view_cmd, 'node_name': node_name, 'graph_name': graph_name}) | |||
| log.debug("Send view cmd.") | |||
| def retrieve_tensor_value(self, name, detail, shape, graph_name=None, prev=False): | |||
| @@ -679,189 +678,10 @@ class DebuggerServer: | |||
| dict, the response. | |||
| """ | |||
| log.info("Receive control request: %s.", params) | |||
| mode = params.get('mode') | |||
| metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA) | |||
| if mode == 'continue': | |||
| reply = self._continue(metadata_stream, params) | |||
| elif mode in ['pause', 'terminate']: | |||
| mode_mapping = { | |||
| 'pause': self._pause, | |||
| 'terminate': self._terminate | |||
| } | |||
| reply = mode_mapping.get(mode)(metadata_stream) | |||
| else: | |||
| log.error("Invalid control mode %s", mode) | |||
| raise DebuggerParamValueError("Invalid control mode.") | |||
| return reply | |||
| def _continue(self, metadata_stream, params): | |||
| """ | |||
| Send RunCMD to MindSpore. | |||
| Args: | |||
| metadata_stream (MetadataHandler): The metadata_handler | |||
| params (dict): The control params. | |||
| Returns: | |||
| dict, metadata info. | |||
| """ | |||
| if metadata_stream.state != ServerStatus.WAITING.value: | |||
| self.cache_store.put_data(metadata_stream.get()) | |||
| log.error("MindSpore is not ready to run. Current state is: %s", metadata_stream.state) | |||
| raise DebuggerContinueError( | |||
| "MindSpore is not ready to run or is running currently." | |||
| ) | |||
| metadata_stream.state = ServerStatus.RUNNING.value | |||
| try: | |||
| self._validate_continue_params(params) | |||
| event = self._construct_run_event(params) | |||
| self._send_watchpoints() | |||
| self.cache_store.put_command(event) | |||
| except MindInsightException as err: | |||
| log.error("Failed to send run event.") | |||
| log.exception(err) | |||
| metadata_stream.state = ServerStatus.WAITING.value | |||
| raise DebuggerContinueError("Failed to send run command.") | |||
| else: | |||
| metadata_stream.enable_recheck = False | |||
| log.debug("Send the RunCMD to command queue.") | |||
| return metadata_stream.get(['state', 'enable_recheck']) | |||
| def _validate_continue_params(self, params): | |||
| """ | |||
| Validate continue params. | |||
| Args: | |||
| params (dict): The control params. | |||
| - level (str): The control granularity, `node`, `step` or `recheck` level. | |||
| Default: `step`. | |||
| - steps (int): Specify the steps that training should run. | |||
| Used when `level` is `step`. | |||
| - name (str): Specify the name of the node. Used when `level` is `node`. | |||
| - graph_name (str): The graph name. | |||
| Raises: | |||
| DebuggerParamValueError: Params are invalid. | |||
| """ | |||
| # validate level | |||
| level = params.get('level', 'step') | |||
| if level not in [RunLevel.NODE.value, RunLevel.STEP.value, RunLevel.RECHECK.value]: | |||
| log.error("Invalid Value. `level` should be `step`, `node` or `recheck`. Got %s", level) | |||
| raise DebuggerParamValueError("level` should be `step`, `node` or `recheck`.") | |||
| # validate steps | |||
| step_num = params.get('steps', 1) | |||
| if not isinstance(step_num, int) or not (step_num == -1 or 0 < step_num <= self._MAX_STEP_NUM): | |||
| log.error("Invalid step value. Step number should be integer and in [1, 2^31 - 1] or -1.") | |||
| raise DebuggerStepNumError | |||
| # validate node name | |||
| if level == RunLevel.NODE.value: | |||
| node_name = params.get('name') | |||
| graph_name = params.get('graph_name') | |||
| self._validate_continue_node_name(node_name, graph_name) | |||
| def _construct_run_event(self, params): | |||
| """ | |||
| Construct run cmd from input control params. | |||
| Args: | |||
| params (dict): The control params. | |||
| - level (str): The control granularity, `node`, `step` or `recheck` level. | |||
| Default: `step`. | |||
| - steps (int): Specify the steps that training should run. | |||
| Used when `level` is `step`. | |||
| - name (str): Specify the name of the node. Used when `level` is `node`. | |||
| - graph_name (str): The graph name. | |||
| Returns: | |||
| EventReply, control event with run command. | |||
| """ | |||
| level = params.get('level', 'step') | |||
| # construct run command events | |||
| event = get_ack_reply() | |||
| if level == 'step': | |||
| steps = params.get('steps', 1) | |||
| run_cmd = RunCMD(run_level='step', run_steps=steps) | |||
| elif level == 'node': | |||
| name = params.get('name', '') | |||
| graph_name = params.get('graph_name') | |||
| if name: | |||
| name = self.cache_store.get_stream_handler(Streams.GRAPH).get_full_name(name, graph_name) | |||
| run_cmd = RunCMD(run_level='node', node_name=name) | |||
| else: | |||
| run_cmd = RunCMD(run_level='recheck') | |||
| event.run_cmd.CopyFrom(run_cmd) | |||
| log.debug("Construct run event. %s", event) | |||
| return event | |||
| def _validate_continue_node_name(self, node_name, graph_name): | |||
| """Validate if the node is a leaf node.""" | |||
| if not node_name: | |||
| return | |||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) | |||
| node_type = graph_stream.get_node_type(node_name, graph_name) | |||
| if is_scope_type(node_type): | |||
| log.error("Scope type node has no tensor history.") | |||
| raise DebuggerParamValueError("Invalid leaf node name.") | |||
| def _send_watchpoints(self): | |||
| """Set watchpoints.""" | |||
| watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT) | |||
| set_commands = watchpoint_stream.get_pending_commands(self.cache_store.get_stream_handler(Streams.GRAPH)) | |||
| if set_commands: | |||
| for set_cmd in set_commands: | |||
| event = get_ack_reply() | |||
| event.set_cmd.CopyFrom(set_cmd) | |||
| self.cache_store.put_command(event) | |||
| watchpoint_stream.sync_set_cmd(set_commands) | |||
| log.debug("Send SetCMD to MindSpore. %s", event) | |||
| def _pause(self, metadata_stream): | |||
| """ | |||
| Pause the training. | |||
| Args: | |||
| metadata_stream (MetadataHandler): The metadata stream handler. | |||
| Returns: | |||
| dict, metadata info. | |||
| """ | |||
| if metadata_stream.state != ServerStatus.RUNNING.value: | |||
| self.cache_store.put_data(metadata_stream.get()) | |||
| log.error("The MindSpore is not running.") | |||
| raise DebuggerPauseError("The MindSpore is not running.") | |||
| metadata_stream.state = 'waiting' | |||
| event = get_ack_reply() | |||
| event.run_cmd.CopyFrom(RunCMD(run_level='step', run_steps=0)) | |||
| self.cache_store.put_command(event) | |||
| metadata_stream.enable_recheck = False | |||
| log.debug("Send the Pause command") | |||
| return metadata_stream.get(['state', 'enable_recheck']) | |||
| def _terminate(self, metadata_stream): | |||
| """ | |||
| Terminate the training. | |||
| Args: | |||
| metadata_stream (MetadataHandler): The metadata stream handler. | |||
| Returns: | |||
| dict, metadata info. | |||
| """ | |||
| metadata_stream.state = 'pending' | |||
| self.cache_store.clean_data() | |||
| self.cache_store.clean_command() | |||
| event = get_ack_reply() | |||
| event.exit = True | |||
| self.cache_store.put_command(event) | |||
| metadata_stream.enable_recheck = False | |||
| log.debug("Send the ExitCMD.") | |||
| return metadata_stream.get(['state', 'enable_recheck']) | |||
| mode = params.pop('mode', None) | |||
| training_controller = TrainingControlOperator(self.cache_store) | |||
| training_controller.validate_mode(mode) | |||
| return training_controller.control(mode, params) | |||
| def retrieve_node_by_bfs(self, node_name, graph_name=None, ascend=False): | |||
| """ | |||
| @@ -904,27 +724,7 @@ class DebuggerServer: | |||
| Returns: | |||
| dict, metadata info. | |||
| """ | |||
| metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA) | |||
| # validate backend status is able to recheck watchpoint | |||
| if not metadata_stream.enable_recheck: | |||
| log.error("Recheck is not available.") | |||
| raise DebuggerRecheckError("Recheck is not available.") | |||
| metadata_stream.state = ServerStatus.RUNNING.value | |||
| metadata_stream.enable_recheck = False | |||
| # send updated watchpoint and recheck command | |||
| try: | |||
| event = self._construct_run_event({'level': 'recheck'}) | |||
| self._send_watchpoints() | |||
| self.cache_store.put_command(event) | |||
| except MindInsightException as err: | |||
| log.error("Failed to send recheck event.") | |||
| log.exception(err) | |||
| metadata_stream.state = ServerStatus.WAITING.value | |||
| metadata_stream.enable_recheck = True | |||
| raise DebuggerContinueError("Failed to send run command.") | |||
| else: | |||
| log.debug("Send the recheck to command queue.") | |||
| return metadata_stream.get(['state', 'enable_recheck']) | |||
| return TrainingControlOperator(self.cache_store).recheck() | |||
| def retrieve_tensor_graph(self, tensor_name, graph_name): | |||
| """ | |||
| @@ -937,6 +737,9 @@ class DebuggerServer: | |||
| Returns: | |||
| dict, tensor graph object. | |||
| """ | |||
| if self.cache_store.get_stream_handler(Streams.METADATA).state != ServerStatus.WAITING.value: | |||
| log.error("Failed to get tensor graph the MindSpore is not in waiting state.") | |||
| raise DebuggerTensorGraphError | |||
| log.info("Retrieve tensor graph for %s from %s", tensor_name, graph_name) | |||
| tensor_graph_ops = TensorDetailInfo(self.cache_store).get_tensor_graph(tensor_name, graph_name) | |||
| return tensor_graph_ops | |||
| @@ -952,6 +755,9 @@ class DebuggerServer: | |||
| Returns: | |||
| dict, tensor hit info. | |||
| """ | |||
| if self.cache_store.get_stream_handler(Streams.METADATA).state != ServerStatus.WAITING.value: | |||
| log.error("Failed to get tensor hits as the MindSpore is not in waiting state.") | |||
| raise DebuggerTensorHitError | |||
| log.info("Retrieve tensor hits for %s from %s", tensor_name, graph_name) | |||
| watch_points = TensorDetailInfo(self.cache_store).get_tensor_watch_points(tensor_name, graph_name) | |||
| return {'watch_points': watch_points} | |||
| @@ -130,6 +130,16 @@ class OpTensor(BaseTensor): | |||
| """The property of tensor stats.""" | |||
| return self._stats | |||
| @stats.setter | |||
| def stats(self, stats): | |||
| """ | |||
| Update tensor stats. | |||
| Args: | |||
| stats (Statistics): Instance of Statistics. | |||
| """ | |||
| self._stats = stats | |||
| @property | |||
| def tensor_comparison(self): | |||
| """The property of tensor_comparison.""" | |||
| @@ -167,15 +177,10 @@ class OpTensor(BaseTensor): | |||
| res = {} | |||
| # the type of tensor_value is one of None, np.ndarray or str | |||
| if isinstance(tensor_value, np.ndarray): | |||
| statistics = TensorUtils.get_statistics_from_tensor(tensor_value) | |||
| if not self.stats: | |||
| self.update_tensor_stats(TensorUtils.get_statistics_from_tensor(self.value)) | |||
| res['statistics'] = TensorUtils.get_statistics_dict(stats=statistics, overall_stats=self.stats) | |||
| res['value'] = tensor_value.tolist() | |||
| elif isinstance(tensor_value, str): | |||
| res['value'] = tensor_value | |||
| res['statistics'] = TensorUtils.get_overall_statistic_dict(self._stats) | |||
| res['statistics'] = self.get_tensor_statistics() | |||
| return res | |||
| def get_tensor_statistics(self): | |||
| @@ -185,9 +190,11 @@ class OpTensor(BaseTensor): | |||
| Returns: | |||
| dict, overall statistics. | |||
| """ | |||
| if not self._stats: | |||
| self._stats = TensorUtils.get_statistics_from_tensor(self.value) | |||
| statistics = TensorUtils.get_overall_statistic_dict(self._stats) | |||
| if self.empty: | |||
| return {} | |||
| if not self.stats: | |||
| self.stats = TensorUtils.get_statistics_from_tensor(self.value) | |||
| statistics = TensorUtils.get_overall_statistic_dict(self.stats) | |||
| return statistics | |||
| def update_tensor_comparisons(self, tensor_comparison): | |||
| @@ -200,16 +207,6 @@ class OpTensor(BaseTensor): | |||
| """ | |||
| self._tensor_comparison = tensor_comparison | |||
| def update_tensor_stats(self, stats): | |||
| """ | |||
| Update tensor stats. | |||
| Args: | |||
| stats (Statistics) instance of Statistics. | |||
| """ | |||
| self._stats = stats | |||
| def get_tensor_value_by_shape(self, shape=None): | |||
| """ | |||
| Get tensor value by shape. | |||
| @@ -467,7 +467,7 @@ class GraphHandler(StreamHandlerBase): | |||
| Get tensor graph according to node name. | |||
| Args: | |||
| tensor_name (str): Tensor name, format is "node_name:<node_value>". | |||
| tensor_name (str): Tensor name from UI, format is "node_name:slot". | |||
| graph_name (str): The relative graph_name of the node. Default: None. | |||
| Returns: | |||
| @@ -624,7 +624,6 @@ class GraphHandler(StreamHandlerBase): | |||
| graph_name = self.graph_names[0] | |||
| return graph_name | |||
| def _add_graph_scope_for_nodes(self, nodes, graph_name): | |||
| """ | |||
| Add graph scope for nodes. | |||
| @@ -13,6 +13,8 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Define the tensor stream handler.""" | |||
| from collections import namedtuple | |||
| import numpy as np | |||
| from mindinsight.datavisual.data_transform.graph.node import NodeTypeEnum | |||
| @@ -23,6 +25,7 @@ from mindinsight.debugger.stream_cache.tensor import OpTensor, ConstTensor | |||
| from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase | |||
| from mindinsight.utils.tensor import TensorUtils, TensorComparison | |||
| TensorBasicInfo = namedtuple('tensor_basic_info', ['full_name', 'node_type', 'iter']) | |||
| class TensorHandler(StreamHandlerBase): | |||
| """Tensor Handler.""" | |||
| @@ -170,7 +173,7 @@ class TensorHandler(StreamHandlerBase): | |||
| log.error("No tensor named %s at the step %s", name, step) | |||
| raise DebuggerParamValueError("No tensor named {}".format(name)) | |||
| tensor_info = tensor.get_full_info(shape) | |||
| self._update_has_prev_step_field(tensor_info, name, node_type, step) | |||
| self._update_has_prev_step_field(tensor_info, name, node_type) | |||
| return {'tensor_value': tensor_info} | |||
| def _get_tensor(self, tensor_name, node_type=None, step=None): | |||
| @@ -219,35 +222,46 @@ class TensorHandler(StreamHandlerBase): | |||
| tensor_name = tensor_info.get('full_name') | |||
| node_type = tensor_info.get('node_type') | |||
| basic_info = self._get_basic_info(tensor_name, node_type) | |||
| flag = self._update_has_prev_step_field(basic_info, tensor_name, node_type, self.cur_step) | |||
| if flag is False: | |||
| missed_tensor = tensor_info.copy() | |||
| missed_tensor['iter'] = 'prev' | |||
| missed_tensors.append(missed_tensor) | |||
| log.debug("Add previous view cmd for %s", tensor_name) | |||
| # add `has_prev_step` field to tensor basic info. | |||
| missing_tensor_infos = self._update_has_prev_step_field(basic_info, tensor_name, node_type) | |||
| if basic_info: | |||
| tensor_info.update(basic_info) | |||
| if basic_info.get('value') is None: | |||
| missed_tensors.append(tensor_info) | |||
| log.debug("Add view cmd for %s", tensor_name) | |||
| else: | |||
| missed_tensors.append(tensor_info) | |||
| log.debug("Add view cmd for %s", tensor_name) | |||
| if missing_tensor_infos: | |||
| missed_tensors.extend(missing_tensor_infos) | |||
| return missed_tensors | |||
| def _update_has_prev_step_field(self, tensor_info, tensor_name, node_type, step): | |||
| def _update_has_prev_step_field(self, tensor_info, tensor_name, node_type): | |||
| """Update has_prev_step field in tensor info.""" | |||
| flag = None | |||
| cur_tensor_value = bool(tensor_info and tensor_info.get('value') is not None) | |||
| if node_type == NodeTypeEnum.PARAMETER.value: | |||
| flag = self._get_prev_tensor_value_status(tensor_name, step) | |||
| if flag and cur_tensor_value: | |||
| tensor_info['has_prev_step'] = True | |||
| return flag | |||
| missing_tensor_infos = self.get_missing_tensor_info(tensor_name, node_type) | |||
| if not missing_tensor_infos and node_type == NodeTypeEnum.PARAMETER.value and self.cur_step > 0: | |||
| tensor_info['has_prev_step'] = True | |||
| return missing_tensor_infos | |||
| def _get_prev_tensor_value_status(self, tensor_name, step): | |||
| def get_missing_tensor_info(self, tensor_name, node_type): | |||
| """ | |||
| Get missing tensor infos. | |||
| Args: | |||
| tensor_name (str): The full name of Tensor. | |||
| node_type (str): The type of the relative node. | |||
| Returns: | |||
| list, list of missing tensor basic information. | |||
| """ | |||
| step = self.cur_step | |||
| missing_tensor_infos = [] | |||
| # check the current step value is missing | |||
| if self._is_tensor_value_missing(tensor_name, step): | |||
| missing_tensor_infos.append(TensorBasicInfo(full_name=tensor_name, node_type=node_type, iter='')) | |||
| log.debug("Add current step view cmd for %s", tensor_name) | |||
| # check the previous step value is missing | |||
| if node_type == NodeTypeEnum.PARAMETER.value and self._is_tensor_value_missing(tensor_name, step - 1): | |||
| missing_tensor_infos.append(TensorBasicInfo(full_name=tensor_name, node_type=node_type, iter='prev')) | |||
| log.debug("Add previous view cmd for %s", tensor_name) | |||
| return missing_tensor_infos | |||
| def _is_tensor_value_missing(self, tensor_name, step): | |||
| """ | |||
| Check whether the tensor value of the given step is missing. | |||
| @@ -256,27 +270,25 @@ class TensorHandler(StreamHandlerBase): | |||
| step (int): The step of the tensor. | |||
| Returns: | |||
| Union[None, bool], the status of previous tensor value. If True, there is valid previous | |||
| tensor value. If False, the tensor value should be queried from client. | |||
| Union[None, bool], the status of tensor value. If False, there is valid | |||
| tensor value. If True, the tensor value should be queried from client. | |||
| If None, ignore. | |||
| """ | |||
| flag = None | |||
| # check if the tensor has previous step value. | |||
| prev_step = step - 1 | |||
| if prev_step < 0: | |||
| return flag | |||
| tensor = self._get_tensor(tensor_name, step=prev_step) | |||
| return bool(tensor and not tensor.empty) | |||
| def get_tensor_value_by_name(self, tensor_name, prev=False): | |||
| """Get tensor value by name in numpy type.""" | |||
| cur_step = self._cur_step | |||
| step = cur_step - 1 if prev else cur_step | |||
| if step < 0: | |||
| log.warning("%d step has no previous value for tensor: %s", cur_step, tensor_name) | |||
| return None | |||
| tensor = self._get_tensor(tensor_name, step=step) | |||
| return bool(not tensor or tensor.empty) | |||
| def get_valid_tensor_by_name(self, tensor_name, prev=False): | |||
| """Get the tensor object by name, or None if it is unavailable or has an empty value.""" | |||
| step = self.prev_step if prev else self.cur_step | |||
| if step < 0: | |||
| log.warning("%d step has no previous value for tensor: %s", self.cur_step, tensor_name) | |||
| return None | |||
| tensor = self._get_tensor(tensor_name, step=step) | |||
| if tensor and tensor.empty: | |||
| log.warning("%s has empty value.", tensor_name) | |||
| return None | |||
| return tensor | |||
| def clean_tensors(self, cur_step): | |||
| @@ -313,35 +325,29 @@ class TensorHandler(StreamHandlerBase): | |||
| Returns: | |||
| dict, the retrieved data. | |||
| """ | |||
| curr_tensor = self.get_tensor_value_by_name(tensor_name) | |||
| prev_tensor = self.get_tensor_value_by_name(tensor_name, prev=True) | |||
| curr_tensor = self.get_valid_tensor_by_name(tensor_name) | |||
| prev_tensor = self.get_valid_tensor_by_name(tensor_name, prev=True) | |||
| if not (curr_tensor and prev_tensor): | |||
| log.error("Get current step and previous step for this tensor name %s failed.", tensor_name) | |||
| raise DebuggerParamValueError(f"Get current step and previous step for this tensor name " | |||
| f"{tensor_name} failed.") | |||
| curr_tensor_slice = curr_tensor.get_tensor_value_by_shape(shape) | |||
| prev_tensor_slice = prev_tensor.get_tensor_value_by_shape(shape) | |||
| # get tensor comparison basic info | |||
| tensor_info = curr_tensor.get_basic_info() | |||
| if isinstance(tensor_info, dict): | |||
| tensor_info.pop('has_prev_step') | |||
| tensor_info.pop('value') | |||
| tensor_info.pop('has_prev_step') | |||
| tensor_info.pop('value') | |||
| # calculate tensor comparison object | |||
| tensor_comparison = curr_tensor.tensor_comparison | |||
| if not tensor_comparison or tensor_comparison.tolerance != tolerance: | |||
| if isinstance(curr_tensor.value, np.ndarray) and isinstance(prev_tensor.value, np.ndarray): | |||
| if curr_tensor.value.shape != prev_tensor.value.shape: | |||
| raise DebuggerParamValueError("The shape of these two step tensors is not the same.") | |||
| tensor_diff = TensorUtils.calc_diff_between_two_tensor(curr_tensor.value, prev_tensor.value, tolerance) | |||
| if not tensor_comparison: | |||
| stats = TensorUtils.get_statistics_from_tensor(tensor_diff) | |||
| tensor_comparison = TensorComparison(tolerance, stats, tensor_diff) | |||
| curr_tensor.update_tensor_comparisons(tensor_comparison) | |||
| else: | |||
| tensor_comparison.update(tolerance=tolerance, value=tensor_diff) | |||
| else: | |||
| raise DebuggerParamValueError("The type of tensor value should be numpy.ndarray.") | |||
| # the type of curr_tensor_slice is one of None, np.ndarray or str | |||
| if curr_tensor.value.shape != prev_tensor.value.shape: | |||
| raise DebuggerParamValueError("The shape of these two step tensors is not the same.") | |||
| tensor_diff = TensorUtils.calc_diff_between_two_tensor(curr_tensor.value, prev_tensor.value, tolerance) | |||
| stats = TensorUtils.get_statistics_from_tensor(tensor_diff) | |||
| tensor_comparison = TensorComparison(tolerance, stats, tensor_diff) | |||
| curr_tensor.update_tensor_comparisons(tensor_comparison) | |||
| # calculate diff value | |||
| # the type of curr_tensor_slice is one of np.ndarray or str | |||
| if isinstance(curr_tensor_slice, np.ndarray) and isinstance(prev_tensor_slice, np.ndarray): | |||
| if not shape: | |||
| tensor_diff_slice = tensor_comparison.value | |||
| @@ -349,22 +355,25 @@ class TensorHandler(StreamHandlerBase): | |||
| tensor_diff_slice = tensor_comparison.value[shape] | |||
| result = np.stack([prev_tensor_slice, curr_tensor_slice, tensor_diff_slice], axis=-1) | |||
| tensor_info['diff'] = result.tolist() | |||
| stats = TensorUtils.get_statistics_from_tensor(tensor_diff_slice) | |||
| curr_tensor_stats = TensorUtils.get_statistics_from_tensor(curr_tensor.value) | |||
| curr_tensor_slice_stats = TensorUtils.get_statistics_from_tensor(curr_tensor_slice) | |||
| prev_tensor_stats = TensorUtils.get_statistics_from_tensor(prev_tensor.value) | |||
| prev_tensor_slice_stats = TensorUtils.get_statistics_from_tensor(prev_tensor_slice) | |||
| tensor_info['curr_step_statistics'] = TensorUtils.get_statistics_dict(stats=curr_tensor_slice_stats, | |||
| overall_stats=curr_tensor_stats) | |||
| tensor_info['prev_step_statistics'] = TensorUtils.get_statistics_dict(stats=prev_tensor_slice_stats, | |||
| overall_stats=prev_tensor_stats) | |||
| tensor_info['statistics'] = TensorUtils.get_statistics_dict(stats=stats, | |||
| overall_stats=tensor_comparison.stats) | |||
| elif isinstance(curr_tensor_slice, str): | |||
| tensor_info['diff'] = curr_tensor_slice | |||
| # add comparision statistics | |||
| tensor_info.update(self._get_comparison_statistics(curr_tensor, prev_tensor)) | |||
| reply = {'tensor_value': tensor_info} | |||
| return reply | |||
@staticmethod
def _get_comparison_statistics(curr_tensor, prev_tensor):
    """
    Build the overall statistics of a tensor comparison.

    Args:
        curr_tensor: Tensor of the current step. Its `tensor_comparison.stats`
            supplies the statistics of the diff and its `value` is summarised.
        prev_tensor: Tensor of the previous step whose `value` is summarised.

    Returns:
        dict, with keys `curr_step_statistics`, `prev_step_statistics` and
        `statistics` (the latter describing the diff).
    """
    # Read the cached diff statistics first, then summarise both raw values.
    diff_overall = curr_tensor.tensor_comparison.stats
    curr_overall = TensorUtils.get_statistics_from_tensor(curr_tensor.value)
    prev_overall = TensorUtils.get_statistics_from_tensor(prev_tensor.value)
    return {
        'curr_step_statistics': TensorUtils.get_overall_statistic_dict(overall_stats=curr_overall),
        'prev_step_statistics': TensorUtils.get_overall_statistic_dict(overall_stats=prev_overall),
        'statistics': TensorUtils.get_overall_statistic_dict(overall_stats=diff_overall),
    }
| def get_tensor_statistics(self, tensor_name, node_type): | |||
| """ | |||
| Get Tensor statistics. | |||
| @@ -378,6 +387,6 @@ class TensorHandler(StreamHandlerBase): | |||
| """ | |||
| res = {} | |||
| tensor = self._get_tensor(tensor_name, node_type) | |||
| if tensor: | |||
| if tensor and not tensor.empty: | |||
| res = tensor.get_tensor_statistics() | |||
| return res | |||
| @@ -15,13 +15,14 @@ | |||
| """This module is aimed to provide with tensor detail info.""" | |||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError | |||
| from mindinsight.debugger.common.log import LOGGER as log | |||
| from mindinsight.debugger.common.utils import Streams | |||
| from mindinsight.debugger.common.utils import Streams, create_view_event_from_tensor_basic_info | |||
| class TensorDetailInfo: | |||
| """Manage tensor detail information.""" | |||
| def __init__(self, cache): | |||
| self._put_command = cache.put_command | |||
| self._tensor_stream = cache.get_stream_handler(Streams.TENSOR) | |||
| self._graph_stream = cache.get_stream_handler(Streams.GRAPH) | |||
| self._hit_stream = cache.get_stream_handler(Streams.WATCHPOINT_HIT) | |||
| @@ -47,7 +48,7 @@ class TensorDetailInfo: | |||
| Get the graph related to specific tensor. | |||
| Args: | |||
| tensor_name (str): The name of tensor. Format like {node_name}:{slot}. | |||
| tensor_name (str): The ui name of tensor. Format like {node_name}:{slot}. | |||
| graph_name (str): The graph name. | |||
| Returns: | |||
| @@ -70,12 +71,16 @@ class TensorDetailInfo: | |||
| self.validate_tensor_name(tensor_name=tensor_name, graph_name=graph_name) | |||
| graph = self._graph_stream.get_tensor_graph(tensor_name, graph_name) | |||
| # add watchpoint hits info and statistics info for each tensor in tensor graph. | |||
| # record missing tensor basic info | |||
| nodes = graph.get('graph', {}).get('nodes', []) | |||
| missing_tensors = [] | |||
| for node in nodes: | |||
| node['graph_name'] = graph_name | |||
| for slot_info in node.get('slots', []): | |||
| self._add_watchpoint_hit_info(slot_info, node) | |||
| self._add_statistic_info(slot_info, node) | |||
| self._add_statistic_info(slot_info, node, missing_tensors) | |||
| # query missing tensor values from client | |||
| self._ask_for_missing_tensor_value(missing_tensors, tensor_name, graph_name) | |||
| return graph | |||
| def _add_watchpoint_hit_info(self, slot_info, node): | |||
| @@ -89,17 +94,38 @@ class TensorDetailInfo: | |||
| tensor_name = ':'.join([node.get('name'), slot_info.get('slot')]) | |||
| slot_info.update(self._hit_stream.get_tensor_hit_infos(tensor_name)) | |||
| def _add_statistic_info(self, slot_info, node): | |||
| def _add_statistic_info(self, slot_info, node, missing_tensors): | |||
| """ | |||
| Get the watchpoint that the tensor hit. | |||
| Args: | |||
| slot_info (dict): Slot object. | |||
| node (dict): Node object. | |||
| missing_tensors (list[TensorBasicInfo]): List of missing tensor infos. | |||
| """ | |||
| tensor_name = ':'.join([node.get('full_name'), slot_info.get('slot')]) | |||
| node_type = node.get('type') | |||
| slot_info['statistics'] = self._tensor_stream.get_tensor_statistics(tensor_name, node_type) | |||
| if not slot_info.get('statistics'): | |||
| log.debug("Get missing tensor basic infos for %s", tensor_name) | |||
| cur_missing_tensors = self._tensor_stream.get_missing_tensor_info(tensor_name, node_type) | |||
| missing_tensors.extend(cur_missing_tensors) | |||
| def _ask_for_missing_tensor_value(self, missing_tensors, tensor_name, graph_name): | |||
| """ | |||
| Send view command to client to query for missing tensor values. | |||
| Args: | |||
| missing_tensors (list[TensorBasicInfo]): List of missing tensor basic infos. | |||
| tensor_name (str): The ui name of tensor. Format like {node_name}:{slot}. | |||
| graph_name (str): The graph name. | |||
| """ | |||
| if not missing_tensors: | |||
| return | |||
| log.debug("Ask for tensor value for: %s", missing_tensors) | |||
| view_cmd = create_view_event_from_tensor_basic_info(missing_tensors) | |||
| self._put_command({'view_cmd': view_cmd, 'tensor_name': tensor_name, 'graph_name': graph_name}) | |||
| log.debug("Send view cmd for tensor-graphs.") | |||
| def get_tensor_watch_points(self, tensor_name, graph_name): | |||
| """ | |||
| @@ -0,0 +1,273 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """This module is aimed to deal with controlling commands.""" | |||
| import enum | |||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerContinueError, DebuggerParamValueError, \ | |||
| DebuggerPauseError, DebuggerRecheckError, DebuggerStepNumError | |||
| from mindinsight.debugger.common.log import LOGGER as log | |||
| from mindinsight.debugger.common.utils import Streams, get_ack_reply, ServerStatus, RunLevel, is_scope_type | |||
| from mindinsight.debugger.proto.debug_grpc_pb2 import RunCMD | |||
| from mindinsight.utils.exceptions import MindInsightException | |||
@enum.unique
class ControlTypeEnum(enum.Enum):
    """Commands that can be used to control the training process."""
    # continue to run training
    CONTINUE = 'continue'
    # suspend training
    PAUSE = 'pause'
    # terminate training
    TERMINATE = 'terminate'
class TrainingControlOperator:
    """
    Turn control requests into commands for the training client.

    The operator validates a request, updates the server state kept by the
    metadata stream and puts the resulting events into the command queue of
    the cache store.
    """
    # max step number should be less than int32
    _MAX_STEP_NUM = 2 ** 31 - 1

    def __init__(self, cache_store):
        """
        Bind the operator to the stream handlers of `cache_store`.

        Args:
            cache_store: The cache object. It must provide `get_stream_handler`,
                `put_command`, `put_data`, `clean_data` and `clean_command`.
        """
        self._cache_store = cache_store
        self._watchpoint_stream = cache_store.get_stream_handler(Streams.WATCHPOINT)
        self._graph_stream = cache_store.get_stream_handler(Streams.GRAPH)
        self._metadata_stream = cache_store.get_stream_handler(Streams.METADATA)

    @staticmethod
    def validate_mode(mode):
        """
        Validate mode.

        Args:
            mode (str): The control mode to check against `ControlTypeEnum`.

        Raises:
            DebuggerParamValueError: If `mode` is not a value of `ControlTypeEnum`.
        """
        enum_members = [item.value for item in ControlTypeEnum]
        if mode not in enum_members:
            log.error("Invalid control mode %s", mode)
            raise DebuggerParamValueError("Invalid control mode.")

    def control(self, mode, params):
        """
        Control the training process.

        Args:
            mode (str): Acceptable control command, including `continue`,
                `pause` and `terminate`.
            params (dict): The control params.

                - level (str): The control granularity, `node` level or `step` level.
                    Default: `step`.
                - steps (int): Specify the steps that training should run.
                    Used when `level` is `step`.
                - name (str): Specify the name of the node. Used when `level` is `node`.
                - graph_name (str): The graph name.

        Returns:
            dict, the response.

        Raises:
            DebuggerParamValueError: If `mode` is not a supported control mode.
        """
        # Reject unknown modes explicitly; otherwise the dispatch below would
        # try to call `None` and fail with an unrelated TypeError.
        self.validate_mode(mode)
        if mode == ControlTypeEnum.CONTINUE.value:
            return self.continue_training(params)
        mode_mapping = {
            ControlTypeEnum.PAUSE.value: self.pause_training,
            ControlTypeEnum.TERMINATE.value: self.terminate_training
        }
        return mode_mapping[mode]()

    def continue_training(self, params):
        """
        Send RunCMD to MindSpore.

        Args:
            params (dict): The control params. `None` is treated as an empty dict.

        Returns:
            dict, metadata info.

        Raises:
            DebuggerContinueError: If the server is not waiting, or if the run
                command could not be validated, constructed or queued.
        """
        # Normalise before the state changes so that a missing params object
        # cannot leave the server marked as running.
        params = params or {}
        metadata_stream = self._metadata_stream
        if metadata_stream.state != ServerStatus.WAITING.value:
            # put the current metadata into the data queue before reporting the error
            self._cache_store.put_data(metadata_stream.get())
            log.error("MindSpore is not ready to run. Current state is: %s", metadata_stream.state)
            raise DebuggerContinueError(
                "MindSpore is not ready to run or is running currently."
            )
        metadata_stream.state = ServerStatus.RUNNING.value
        try:
            self._validate_continue_params(params)
            event = self._construct_run_event(params)
            self._send_watchpoints()
            self._cache_store.put_command(event)
        except MindInsightException as err:
            log.error("Failed to send run event.")
            log.exception(err)
            # roll the state back so that the request can be retried
            metadata_stream.state = ServerStatus.WAITING.value
            raise DebuggerContinueError("Failed to send run command.") from err
        else:
            metadata_stream.enable_recheck = False
            log.debug("Send the RunCMD to command queue.")
        return metadata_stream.get(['state', 'enable_recheck'])

    def _validate_continue_params(self, params):
        """
        Validate continue params.

        Args:
            params (dict): The control params.

                - level (str): The control granularity, `node`, `step` or `recheck` level.
                    Default: `step`.
                - steps (int): Specify the steps that training should run.
                    Used when `level` is `step`.
                - name (str): Specify the name of the node. Used when `level` is `node`.
                - graph_name (str): The graph name.

        Raises:
            DebuggerParamValueError: Params are invalid.
            DebuggerStepNumError: Step number are invalid.
        """
        # validate level
        level = params.get('level', 'step')
        if level not in [RunLevel.NODE.value, RunLevel.STEP.value, RunLevel.RECHECK.value]:
            log.error("Invalid Value. `level` should be `step`, `node` or `recheck`. Got %s", level)
            raise DebuggerParamValueError("`level` should be `step`, `node` or `recheck`.")

        # validate steps: either -1 or an integer within [1, _MAX_STEP_NUM]
        step_num = params.get('steps', 1)
        if not isinstance(step_num, int) or not (step_num == -1 or 0 < step_num <= self._MAX_STEP_NUM):
            log.error("Invalid step value. Step number should be integer and in [1, 2^31 - 1] or -1.")
            raise DebuggerStepNumError

        # validate node name
        if level == RunLevel.NODE.value:
            node_name = params.get('name')
            graph_name = params.get('graph_name')
            self._validate_continue_node_name(node_name, graph_name)

    def _validate_continue_node_name(self, node_name, graph_name):
        """
        Validate if the node is a leaf node.

        An empty node name is accepted without any lookup.

        Args:
            node_name (str): The name of the node to run to.
            graph_name (str): The graph name.

        Raises:
            DebuggerParamValueError: If the node is a scope type node.
        """
        if not node_name:
            return
        node_type = self._graph_stream.get_node_type(node_name, graph_name)
        if is_scope_type(node_type):
            log.error("Scope type node has no tensor history.")
            raise DebuggerParamValueError("Invalid leaf node name.")

    def _construct_run_event(self, params):
        """
        Construct run cmd from input control params.

        Args:
            params (dict): The control params.

                - level (str): The control granularity, `node`, `step` or `recheck` level.
                    Default: `step`.
                - steps (int): Specify the steps that training should run.
                    Used when `level` is `step`.
                - name (str): Specify the name of the node. Used when `level` is `node`.
                - graph_name (str): The graph name.

        Returns:
            EventReply, control event with run command.
        """
        level = params.get('level', 'step')
        # construct run command events
        event = get_ack_reply()
        if level == 'step':
            steps = params.get('steps', 1)
            run_cmd = RunCMD(run_level='step', run_steps=steps)
        elif level == 'node':
            name = params.get('name', '')
            graph_name = params.get('graph_name')
            if name:
                # use the graph handler bound in __init__, like the other methods do
                name = self._graph_stream.get_full_name(name, graph_name)
            run_cmd = RunCMD(run_level='node', node_name=name)
        else:
            run_cmd = RunCMD(run_level='recheck')

        event.run_cmd.CopyFrom(run_cmd)
        log.debug("Construct run event. %s", event)
        return event

    def _send_watchpoints(self):
        """Send watchpoints to client."""
        set_commands = self._watchpoint_stream.get_pending_commands(self._graph_stream)
        if not set_commands:
            return
        for set_cmd in set_commands:
            event = get_ack_reply()
            event.set_cmd.CopyFrom(set_cmd)
            self._cache_store.put_command(event)
            log.debug("Send SetCMD to MindSpore. %s", event)
        # tell the watchpoint stream which set commands have been queued
        self._watchpoint_stream.sync_set_cmd(set_commands)

    def pause_training(self):
        """
        Pause the training.

        Returns:
            dict, metadata info.

        Raises:
            DebuggerPauseError: If the server is not in running state.
        """
        metadata_stream = self._metadata_stream
        if metadata_stream.state != ServerStatus.RUNNING.value:
            # put the current metadata into the data queue before reporting the error
            self._cache_store.put_data(metadata_stream.get())
            log.error("The MindSpore is not running.")
            raise DebuggerPauseError("The MindSpore is not running.")
        metadata_stream.state = ServerStatus.WAITING.value
        # a step-level run command with zero steps is what gets queued to pause
        event = get_ack_reply()
        event.run_cmd.CopyFrom(RunCMD(run_level='step', run_steps=0))
        self._cache_store.put_command(event)
        metadata_stream.enable_recheck = False
        log.debug("Send the Pause command")
        return metadata_stream.get(['state', 'enable_recheck'])

    def terminate_training(self):
        """
        Terminate the training.

        Cached data and queued commands are dropped before the exit event is queued.

        Returns:
            dict, metadata info.
        """
        metadata_stream = self._metadata_stream
        metadata_stream.state = 'pending'
        self._cache_store.clean_data()
        self._cache_store.clean_command()
        event = get_ack_reply()
        event.exit = True
        self._cache_store.put_command(event)
        metadata_stream.enable_recheck = False
        log.debug("Send the ExitCMD.")
        return metadata_stream.get(['state', 'enable_recheck'])

    def recheck(self):
        """
        Recheck all watchpoints.

        Returns:
            dict, metadata info.

        Raises:
            DebuggerRecheckError: If recheck is not enabled.
            DebuggerContinueError: If the recheck command could not be queued.
        """
        metadata_stream = self._metadata_stream
        # validate backend status is able to recheck watchpoint
        if not metadata_stream.enable_recheck:
            log.error("Recheck is not available.")
            raise DebuggerRecheckError("Recheck is not available.")
        metadata_stream.state = ServerStatus.RUNNING.value
        metadata_stream.enable_recheck = False
        # send updated watchpoint and recheck command
        try:
            event = self._construct_run_event({'level': 'recheck'})
            self._send_watchpoints()
            self._cache_store.put_command(event)
        except MindInsightException as err:
            log.error("Failed to send recheck event.")
            log.exception(err)
            # restore the state that was in effect before the attempt
            metadata_stream.state = ServerStatus.WAITING.value
            metadata_stream.enable_recheck = True
            # NOTE(review): DebuggerRecheckError looks like the better fit here;
            # the type is kept because callers may catch DebuggerContinueError.
            raise DebuggerContinueError("Failed to send recheck command.") from err
        else:
            log.debug("Send the recheck to command queue.")
        return metadata_stream.get(['state', 'enable_recheck'])
| @@ -316,6 +316,8 @@ class TensorUtils: | |||
| Returns: | |||
| dict, overall statistics. | |||
| """ | |||
| if not overall_stats: | |||
| return {} | |||
| res = { | |||
| "overall_max": float(overall_stats.max), | |||
| "overall_min": float(overall_stats.min), | |||
| @@ -17,13 +17,6 @@ | |||
| ] | |||
| ], | |||
| "curr_step_statistics": { | |||
| "max": 6.0, | |||
| "min": 1.0, | |||
| "avg": 3.5, | |||
| "count": 6, | |||
| "nan_count": 0, | |||
| "neg_inf_count": 0, | |||
| "pos_inf_count": 0, | |||
| "overall_max": 6.0, | |||
| "overall_min": 1.0, | |||
| "overall_avg": 3.5, | |||
| @@ -36,13 +29,6 @@ | |||
| "overall_pos_zero_count": 6.0 | |||
| }, | |||
| "prev_step_statistics": { | |||
| "max": 6.0, | |||
| "min": 1.0, | |||
| "avg": 3.5, | |||
| "count": 6, | |||
| "nan_count": 0, | |||
| "neg_inf_count": 0, | |||
| "pos_inf_count": 0, | |||
| "overall_max": 6.0, | |||
| "overall_min": 1.0, | |||
| "overall_avg": 3.5, | |||
| @@ -55,13 +41,6 @@ | |||
| "overall_pos_zero_count": 6.0 | |||
| }, | |||
| "statistics": { | |||
| "max": 0.0, | |||
| "min": 0.0, | |||
| "avg": 0.0, | |||
| "count": 6, | |||
| "nan_count": 0, | |||
| "neg_inf_count": 0, | |||
| "pos_inf_count": 0, | |||
| "overall_max": 0.0, | |||
| "overall_min": 0.0, | |||
| "overall_avg": 0.0, | |||
| @@ -1 +1,29 @@ | |||
| {"tensor_value": {"full_name": "Default/TransData-op99:0", "step": 1, "dtype": "DT_FLOAT32", "shape": [2, 3], "has_prev_step": false, "statistics": {"max": 6.0, "min": 5.0, "avg": 5.5, "count": 2, "nan_count": 0, "neg_inf_count": 0, "pos_inf_count": 0, "overall_max": 6.0, "overall_min": 1.0, "overall_avg": 3.5, "overall_count": 6, "overall_nan_count": 0, "overall_neg_inf_count": 0, "overall_pos_inf_count": 0, "overall_zero_count": 0.0, "overall_neg_zero_count": 0.0, "overall_pos_zero_count": 6.0}, "value": [5.0, 6.0], "name": "Default/TransData-op99:0"}} | |||
| { | |||
| "tensor_value": { | |||
| "full_name": "Default/TransData-op99:0", | |||
| "step": 1, | |||
| "dtype": "DT_FLOAT32", | |||
| "shape": [ | |||
| 2, | |||
| 3 | |||
| ], | |||
| "has_prev_step": false, | |||
| "statistics": { | |||
| "overall_max": 6.0, | |||
| "overall_min": 1.0, | |||
| "overall_avg": 3.5, | |||
| "overall_count": 6, | |||
| "overall_nan_count": 0, | |||
| "overall_neg_inf_count": 0, | |||
| "overall_pos_inf_count": 0, | |||
| "overall_zero_count": 0.0, | |||
| "overall_neg_zero_count": 0.0, | |||
| "overall_pos_zero_count": 6.0 | |||
| }, | |||
| "value": [ | |||
| 5.0, | |||
| 6.0 | |||
| ], | |||
| "name": "Default/TransData-op99:0" | |||
| } | |||
| } | |||
| @@ -217,5 +217,5 @@ class MockDebuggerClientThread: | |||
| return self._debugger_client_thread | |||
| def __exit__(self, exc_type, exc_val, exc_tb): | |||
| self._debugger_client_thread.join(timeout=3) | |||
| self._debugger_client_thread.join(timeout=2) | |||
| self._debugger_client.flag = False | |||
| @@ -17,6 +17,7 @@ from mindinsight.debugger.common.utils import ServerStatus | |||
| from mindinsight.debugger.stream_handler.metadata_handler import MetadataHandler | |||
| from mindinsight.debugger.proto.debug_grpc_pb2 import Metadata | |||
| class TestMetadataHandler: | |||
| """test class for MetadataHandler""" | |||
| def setup_method(self): | |||
| @@ -40,7 +40,7 @@ class TestTensorHandler: | |||
| def test_get_tensor_value_by_name_none(self): | |||
| """Test get_tensor_value_by_name.""" | |||
| res = self.tensor_handler.get_tensor_value_by_name('tensor_name', True) | |||
| res = self.tensor_handler.get_valid_tensor_by_name('tensor_name', True) | |||
| assert res is None | |||
| @mock.patch.object(log, "error") | |||
| @@ -0,0 +1,15 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Test for debugger stream operator.""" | |||
| @@ -0,0 +1,66 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| Function: | |||
| Test debugger training control operator. | |||
| Usage: | |||
| pytest tests/ut/debugger/stream_operator/test_training_control_operator.py | |||
| """ | |||
| from unittest import mock | |||
| import pytest | |||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError | |||
| from mindinsight.debugger.debugger_cache import DebuggerCache | |||
| from mindinsight.debugger.proto.debug_grpc_pb2 import RunCMD | |||
| from mindinsight.debugger.stream_handler import GraphHandler, MetadataHandler | |||
| from mindinsight.debugger.stream_operator.training_control_operator import TrainingControlOperator | |||
class TestTrainingControlOperator:
    """Unit tests for the training control operator."""

    @classmethod
    def setup_class(cls):
        """Declare the operator attribute once for the whole test class."""
        cls._server = None

    def setup_method(self):
        """Create a fresh operator on top of an initialized cache before each test."""
        cache = DebuggerCache()
        cache.initialize()
        self._server = TrainingControlOperator(cache)

    @mock.patch.object(GraphHandler, 'get_node_type')
    def test_validate_leaf_name(self, *args):
        """A scope type node must be rejected as a continue target."""
        mocked_get_node_type = args[0]
        mocked_get_node_type.return_value = 'name_scope'
        with pytest.raises(DebuggerParamValueError, match='Invalid leaf node name.'):
            self._server._validate_continue_node_name(node_name='mock_node_name', graph_name='mock_graph_name')

    @pytest.mark.parametrize('mode, cur_state, state', [
        ('continue', 'waiting', 'running'),
        ('pause', 'running', 'waiting'),
        ('terminate', 'waiting', 'pending')])
    def test_control(self, mode, cur_state, state):
        """Each control mode must move the server from `cur_state` to `state`."""
        expected = {'metadata': {'enable_recheck': False, 'state': state}}
        with mock.patch.object(MetadataHandler, 'state', cur_state):
            reply = self._server.control(mode=mode, params={})
            assert reply == expected

    def test_construct_run_event(self):
        """A node level request without a name yields a node RunCMD with an empty name."""
        event = self._server._construct_run_event({'level': 'node'})
        expected_cmd = RunCMD(run_level='node', node_name='')
        assert event.run_cmd == expected_cmd
| @@ -68,7 +68,7 @@ class MockDataGenerator: | |||
| view_event = get_ack_reply() | |||
| ms_tensor = view_event.view_cmd.tensors.add() | |||
| ms_tensor.node_name, ms_tensor.slot = 'mock_node_name', '0' | |||
| event = {'view_cmd': view_event, 'node_name': 'mock_node_name'} | |||
| event = {'view_cmd': view_event, 'node_name': 'mock_node_name', 'graph_name': 'mock_graph_name'} | |||
| return event | |||
| @staticmethod | |||
| @@ -180,10 +180,10 @@ class TestDebuggerGrpcServer: | |||
| def test_deal_with_old_command_with_view_cmd(self, *args): | |||
| """Test deal with view command.""" | |||
| cmd = MockDataGenerator.get_view_cmd() | |||
| args[1].return_value = ('0', cmd) | |||
| args[1].return_value = ('0', cmd.copy()) | |||
| res = self._server._deal_with_old_command() | |||
| assert res == cmd.get('view_cmd') | |||
| expect_received_view_cmd = {'node_name': cmd.get('node_name'), 'wait_for_tensor': True} | |||
| assert res == cmd.pop('view_cmd') | |||
| expect_received_view_cmd = {'node_info': cmd, 'wait_for_tensor': True} | |||
| assert getattr(self._server, '_received_view_cmd') == expect_received_view_cmd | |||
| @mock.patch.object(DebuggerCache, 'get_command') | |||
| @@ -201,7 +201,7 @@ class TestDebuggerGrpcServer: | |||
| """Test wait for run command.""" | |||
| pause_cmd = MockDataGenerator.get_run_cmd(steps=0) | |||
| empty_view_cmd = MockDataGenerator.get_view_cmd() | |||
| empty_view_cmd.pop('node_name') | |||
| empty_view_cmd.pop('view_cmd') | |||
| run_cmd = MockDataGenerator.get_run_cmd(steps=2) | |||
| args[0].side_effect = [('0', pause_cmd), ('0', empty_view_cmd), ('0', run_cmd)] | |||
| setattr(self._server, '_status', ServerStatus.WAITING) | |||
| @@ -32,7 +32,6 @@ from mindinsight.debugger.common.utils import Streams | |||
| from mindinsight.debugger.debugger_cache import DebuggerCache | |||
| from mindinsight.debugger.debugger_server import DebuggerServer | |||
| from mindinsight.debugger.debugger_server import grpc_server_base | |||
| from mindinsight.debugger.proto.debug_grpc_pb2 import RunCMD | |||
| from mindinsight.debugger.stream_handler import GraphHandler, WatchpointHandler, MetadataHandler, \ | |||
| TensorHandler | |||
| from tests.ut.debugger.configurations import compare_debugger_result_with_file, mock_tensor_history | |||
| @@ -154,13 +153,6 @@ class TestDebuggerServer: | |||
| res = self._server.retrieve_tensor_history('mock_node_name') | |||
| compare_debugger_result_with_file(res, 'debugger_server/retrieve_tensor_history.json') | |||
| @mock.patch.object(GraphHandler, 'get_node_type') | |||
| def test_validate_leaf_name(self, *args): | |||
| """Test validate leaf name.""" | |||
| args[0].return_value = 'name_scope' | |||
| with pytest.raises(DebuggerParamValueError, match='Invalid leaf node name.'): | |||
| self._server._validate_continue_node_name(node_name='mock_node_name', graph_name='mock_graph_name') | |||
| @mock.patch.object(TensorHandler, 'get') | |||
| @mock.patch.object(DebuggerServer, '_get_tensor_name_and_type_by_ui_name') | |||
| def test_retrieve_tensor_value(self, *args): | |||
| @@ -187,7 +179,6 @@ class TestDebuggerServer: | |||
| res = self._server._retrieve_watchpoint({'watch_point_id': 1}) | |||
| assert res == mock_watchpoint | |||
| @mock.patch.object(DebuggerServer, '_validate_continue_node_name') | |||
| @mock.patch.object(DebuggerServer, '_get_tensor_history') | |||
| @mock.patch.object(DebuggerServer, '_get_nodes_info', return_value={'graph': {}}) | |||
| def test_retrieve_watchpoint_hit(self, *args): | |||
| @@ -238,18 +229,3 @@ class TestDebuggerServer: | |||
| args[0].return_value = None | |||
| res = self._server.delete_watchpoint(1) | |||
| assert res == {'metadata': {'enable_recheck': True, 'state': 'waiting'}} | |||
| @pytest.mark.parametrize('mode, cur_state, state', [ | |||
| ('continue', 'waiting', 'running'), | |||
| ('pause', 'running', 'waiting'), | |||
| ('terminate', 'waiting', 'pending')]) | |||
| def test_control(self, mode, cur_state, state): | |||
| """Test control request.""" | |||
| with mock.patch.object(MetadataHandler, 'state', cur_state): | |||
| res = self._server.control({'mode': mode}) | |||
| assert res == {'metadata': {'enable_recheck': False, 'state': state}} | |||
| def test_construct_run_event(self): | |||
| """Test construct run event.""" | |||
| res = self._server._construct_run_event({'level': 'node'}) | |||
| assert res.run_cmd == RunCMD(run_level='node', node_name='') | |||