From a2bab88378d3fd8fcc18ba4be789afc02608c73b Mon Sep 17 00:00:00 2001 From: yelihua Date: Wed, 9 Dec 2020 20:47:26 +0800 Subject: [PATCH] enable recheck at step 0 --- mindinsight/debugger/stream_cache/tensor.py | 7 ++--- .../stream_handler/metadata_handler.py | 2 +- .../stream_handler/watchpoint_handler.py | 29 +++++++++++++------ .../training_control_operator.py | 5 ++++ tests/st/func/debugger/mock_ms_client.py | 2 ++ 5 files changed, 31 insertions(+), 14 deletions(-) diff --git a/mindinsight/debugger/stream_cache/tensor.py b/mindinsight/debugger/stream_cache/tensor.py index be513aad..15af4fa2 100644 --- a/mindinsight/debugger/stream_cache/tensor.py +++ b/mindinsight/debugger/stream_cache/tensor.py @@ -247,8 +247,7 @@ class OpTensor(BaseTensor): class ConstTensor(BaseTensor): """Tensor data structure for Const Node.""" - STRING_TYPE = 'DT_STRING' - BOOL_TYPE = 'DT_BOOL' + _STRING_TYPE = 'DT_STRING' def __init__(self, const_proto): # the type of const_proto is NamedValueProto @@ -298,7 +297,7 @@ class ConstTensor(BaseTensor): if field_name != 'dtype': tensor_value = field_value break - if tensor_value and self.dtype != self.STRING_TYPE: + if tensor_value and self.dtype != self._STRING_TYPE: tensor_value = np.array(tensor_value, dtype=NUMPY_TYPE_MAP.get(self.dtype)) return tensor_value @@ -323,7 +322,7 @@ class ConstTensor(BaseTensor): Returns: dict, overall statistics. """ - if self.empty or self.dtype == self.STRING_TYPE: + if self.empty or self.dtype == self._STRING_TYPE: return {} stats = TensorUtils.get_statistics_from_tensor(self.value) statistics = TensorUtils.get_overall_statistic_dict(stats) diff --git a/mindinsight/debugger/stream_handler/metadata_handler.py b/mindinsight/debugger/stream_handler/metadata_handler.py index 58785c8c..1f161aeb 100644 --- a/mindinsight/debugger/stream_handler/metadata_handler.py +++ b/mindinsight/debugger/stream_handler/metadata_handler.py @@ -109,7 +109,7 @@ class MetadataHandler(StreamHandlerBase): @property def enable_recheck(self): """The property of enable_recheck.""" - return self._enable_recheck and self._state == ServerStatus.WAITING and self._step > 0 + return self._enable_recheck and self._state == ServerStatus.WAITING @enable_recheck.setter def enable_recheck(self, value): diff --git a/mindinsight/debugger/stream_handler/watchpoint_handler.py b/mindinsight/debugger/stream_handler/watchpoint_handler.py index 6d7f25ce..9ff6cf4f 100644 --- a/mindinsight/debugger/stream_handler/watchpoint_handler.py +++ b/mindinsight/debugger/stream_handler/watchpoint_handler.py @@ -557,15 +557,7 @@ def validate_watch_condition_params(condition_mgr, watch_condition): raise DebuggerParamValueError("Invalid name of parameter.") condition_param = condition.get_parameter_definition(condition_param_name) - if condition_param.type.name in (ValueTypeEnum.FLOAT64.name, ValueTypeEnum.INT64.name) \ - and not isinstance(param.get("value"), (float, int)): - log.error("Number param should be given for condition: %s", condition_id) - raise DebuggerParamValueError("Number param should be given.") - - if condition_param.type.name == ValueTypeEnum.BOOL.name \ - and not isinstance(param.get("value"), bool): - log.error("Bool param should be given for condition: %s", condition_id) - raise DebuggerParamValueError("Bool param should be given.") + validate_param_type(condition_id, condition_param, param) if not condition_param.is_valid(param.get("value")): log.error("Param %s out of range for condition: %s", condition_param_name, condition_id) @@ -596,6 +588,25 @@ def validate_watch_condition_params(condition_mgr, watch_condition): raise DebuggerParamValueError("Invalid support params.") +def validate_param_type(condition_id, condition_param, param): + """ + Validate parameter type. + + Args: + condition_id (str): Condition id. Should be in WATCHPOINT_CONDITION_MAPPING. + condition_param (ConditionParameter): Condition Parameter object. + param (dict): Condition parameter value. + """ + if condition_param.type.name in (ValueTypeEnum.FLOAT64.name, ValueTypeEnum.INT64.name) \ + and not isinstance(param.get("value"), (float, int)): + log.error("Number param should be given for condition: %s", condition_id) + raise DebuggerParamValueError("Number param should be given.") + if condition_param.type.name == ValueTypeEnum.BOOL.name \ + and not isinstance(param.get("value"), bool): + log.error("Bool param should be given for condition: %s", condition_id) + raise DebuggerParamValueError("Bool param should be given.") + + def set_default_param(condition_mgr, watch_condition): """ Set default param. diff --git a/mindinsight/debugger/stream_operator/training_control_operator.py b/mindinsight/debugger/stream_operator/training_control_operator.py index 913b185c..f19c9c3d 100644 --- a/mindinsight/debugger/stream_operator/training_control_operator.py +++ b/mindinsight/debugger/stream_operator/training_control_operator.py @@ -99,7 +99,12 @@ class TrainingControlOperator: try: self._validate_continue_params(params) event = self._construct_run_event(params) + # whether need to send recheck before continue, especially for initialization watchpoint + recheck_flag = bool(self._metadata_stream.step == 0 and self._watchpoint_stream.is_recheckable()) self._send_watchpoints() + if recheck_flag: + self._cache_store.put_command(self._construct_run_event({'level': 'recheck'})) + log.info("Send recheck command for initialization watchpoints before continue command.") self._cache_store.put_command(event) except MindInsightException as err: log.error("Failed to send run event.") diff --git a/tests/st/func/debugger/mock_ms_client.py b/tests/st/func/debugger/mock_ms_client.py index b558b359..c698a9a7 100644 --- a/tests/st/func/debugger/mock_ms_client.py +++ b/tests/st/func/debugger/mock_ms_client.py @@ -14,6 +14,7 @@ # ============================================================================ """Mocked MindSpore debugger client.""" from threading import Thread +from time import sleep import grpc import numpy as np @@ -77,6 +78,7 @@ class MockDebuggerClient: wait_flag = True while self.flag and wait_flag: if self._step > total_steps: + sleep(0.5) self.send_metadata_cmd(training_done=True) return wait_flag = self._wait_cmd()