diff --git a/mindinsight/debugger/common/exceptions/error_code.py b/mindinsight/debugger/common/exceptions/error_code.py index 2528ae39..e87460b4 100644 --- a/mindinsight/debugger/common/exceptions/error_code.py +++ b/mindinsight/debugger/common/exceptions/error_code.py @@ -28,6 +28,7 @@ class DebuggerErrors(DebuggerErrorCodes): PARAM_TYPE_ERROR = 0 | _PARAM_ERROR_MASK PARAM_VALUE_ERROR = 1 | _PARAM_ERROR_MASK STEP_NUM_ERROR = 2 | _PARAM_ERROR_MASK + DEBUGGER_CONDITION_UNAVAILABLE_ERROR = 3 | _PARAM_ERROR_MASK NODE_NOT_IN_GRAPH_ERROR = 0 | _DEBUGGER_GRAPH_ERROR GRAPH_NOT_EXIST_ERROR = 1 | _DEBUGGER_GRAPH_ERROR @@ -41,6 +42,7 @@ class DebuggerErrors(DebuggerErrorCodes): RECHECK_ERROR = 6 | _DEBUGGER_RUNNING_ERROR TENSOR_GRAPH_ERROR = 7 | _DEBUGGER_RUNNING_ERROR TENSOR_HIT_ERROR = 8 | _DEBUGGER_RUNNING_ERROR + SET_RECOMMEND_WATCHPOINT_ERROR = 9 | _DEBUGGER_RUNNING_ERROR @unique @@ -48,6 +50,7 @@ class DebuggerErrorMsg(Enum): """Debugger error messages.""" PARAM_TYPE_ERROR = "TypeError. {}" PARAM_VALUE_ERROR = "ValueError. {}" + DEBUGGER_CONDITION_UNAVAILABLE_ERROR = "Condition is unavailable. {}" GRAPH_NOT_EXIST_ERROR = "The graph does not exist." @@ -59,3 +62,4 @@ class DebuggerErrorMsg(Enum): RECHECK_ERROR = "Recheck failed. {}" TENSOR_GRAPH_ERROR = "Get tensor graphs failed." TENSOR_HIT_ERROR = "Get tensor hits failed." + SET_RECOMMEND_WATCHPOINT_ERROR = "Set Recommend Watchpoints failed." diff --git a/mindinsight/debugger/common/exceptions/exceptions.py b/mindinsight/debugger/common/exceptions/exceptions.py index 0c27533e..060325fe 100644 --- a/mindinsight/debugger/common/exceptions/exceptions.py +++ b/mindinsight/debugger/common/exceptions/exceptions.py @@ -168,3 +168,25 @@ class DebuggerTensorHitError(MindInsightException): message=DebuggerErrorMsg.TENSOR_HIT_ERROR.value, http_code=400 ) + + +class DebuggerSetRecommendWatchpointsError(MindInsightException): + """The set recommend watchpoints error in debugger module.""" + + def __init__(self): + super(DebuggerSetRecommendWatchpointsError, self).__init__( + error=DebuggerErrors.SET_RECOMMEND_WATCHPOINT_ERROR, + message=DebuggerErrorMsg.SET_RECOMMEND_WATCHPOINT_ERROR.value, + http_code=400 + ) + + +class DebuggerConditionUnavailableError(MindInsightException): + """The condition unavailable error in debugger module.""" + + def __init__(self, msg): + super(DebuggerConditionUnavailableError, self).__init__( + error=DebuggerErrors.DEBUGGER_CONDITION_UNAVAILABLE_ERROR, + message=DebuggerErrorMsg.DEBUGGER_CONDITION_UNAVAILABLE_ERROR.value.format(msg), + http_code=400 + ) diff --git a/mindinsight/debugger/debugger_grpc_server.py b/mindinsight/debugger/debugger_grpc_server.py index 00055569..9122bbdf 100644 --- a/mindinsight/debugger/debugger_grpc_server.py +++ b/mindinsight/debugger/debugger_grpc_server.py @@ -495,7 +495,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): } hit_params = {} for param in watchpoint_hit_proto.watch_condition.params: - if param.actual_value: + if param.actual_value is not None: hit_params[param.name] = param.actual_value for i, param in enumerate(watchpoint_hit['watchpoint'].condition['params']): name = param['name'] @@ -503,7 +503,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): watchpoint_hit['watchpoint'].condition['params'][i]['actual_value'] = hit_params[name] else: watchpoint_hit['watchpoint'].condition['params'][i]['actual_value'] = None - if watchpoint_hit_proto.error_code: + if watchpoint_hit_proto.error_code is not None: watchpoint_hit['error_code'] = watchpoint_hit_proto.error_code watchpoint_hits.append(watchpoint_hit) self._received_hit = watchpoint_hits diff --git a/mindinsight/debugger/debugger_server.py b/mindinsight/debugger/debugger_server.py index e99a1287..13ae323f 100644 --- a/mindinsight/debugger/debugger_server.py +++ b/mindinsight/debugger/debugger_server.py @@ -28,7 +28,7 @@ from mindinsight.datavisual.data_transform.graph import NodeTypeEnum from mindinsight.datavisual.utils.tools import to_float from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \ DebuggerParamTypeError, DebuggerCompareTensorError, DebuggerTensorGraphError, \ - DebuggerTensorHitError, MindInsightException + DebuggerTensorHitError, DebuggerSetRecommendWatchpointsError, MindInsightException from mindinsight.debugger.common.log import LOGGER as log from mindinsight.debugger.common.utils import ServerStatus, \ create_view_event_from_tensor_basic_info, Streams @@ -81,10 +81,13 @@ class DebuggerServer: log.error("Bool param should be given for set_recommended") raise DebuggerParamValueError("Bool param should be given.") metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA) + if metadata_stream.recommendation_confirmed: + log.error("User has confirmed setting recommended watchpoints") + raise DebuggerSetRecommendWatchpointsError() condition_context = ConditionContext(metadata_stream.backend, metadata_stream.step) log.debug("Train_id: %s, backend: %s", train_id, condition_context.backend) res = metadata_stream.get(['state', 'enable_recheck']) - if set_recommended and not metadata_stream.recommendation_confirmed: + if set_recommended: res['id'] = self._add_recommended_watchpoints(condition_context) metadata_stream.recommendation_confirmed = True return res diff --git a/mindinsight/debugger/stream_handler/watchpoint_handler.py b/mindinsight/debugger/stream_handler/watchpoint_handler.py index 3436f63c..6d7f25ce 100644 --- a/mindinsight/debugger/stream_handler/watchpoint_handler.py +++ b/mindinsight/debugger/stream_handler/watchpoint_handler.py @@ -25,6 +25,10 @@ from mindinsight.debugger.stream_cache.watchpoint import Watchpoint, WatchpointH from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase +RANGE_START = 'range_start_inclusive' +RANGE_END = 'range_end_inclusive' + + class WatchpointHandler(StreamHandlerBase): """Watchpoint Handler.""" @@ -446,7 +450,7 @@ class WatchpointHitHandler(StreamHandlerBase): watch_points.append(watchpoint) if watch_points: - watch_points.sort(key=_watchpoint_id) + watch_points.sort(key=lambda watch_point: watch_point.get('id')) res = { 'slot': slot, 'watch_points': watch_points @@ -540,6 +544,7 @@ def validate_watch_condition_params(condition_mgr, watch_condition): check_param_num = 0 support_params = set() defined_support_params = set() + range_param = {RANGE_START: None, RANGE_END: None} for param in params: if len(param) > 2: log.error("Invalid param keys for condition: %s", condition_id) @@ -573,6 +578,9 @@ def validate_watch_condition_params(condition_mgr, watch_condition): else: support_params.add(condition_param.name) + if condition_param_name in range_param: + range_param[condition_param_name] = param.get("value") + if check_param_num > 1: log.error("Multiple check params for condition: %s", condition_id) raise DebuggerParamValueError("Multiple check params.") @@ -581,6 +589,12 @@ def validate_watch_condition_params(condition_mgr, watch_condition): log.error("Invalid support params for condition: %s", condition_id) raise DebuggerParamValueError("Invalid support params.") + if range_param.get(RANGE_START) is not None and \ + range_param.get(RANGE_END) is not None and range_param.get(RANGE_START) > \ + range_param.get(RANGE_END): + log.error("Invalid support params for condition: %s", condition_id) + raise DebuggerParamValueError("Invalid support params.") + def set_default_param(condition_mgr, watch_condition): """ @@ -615,10 +629,6 @@ def set_default_param(condition_mgr, watch_condition): return watch_condition -def _watchpoint_id(watchpoint): - return watchpoint.get('id') - - def _get_error_list(error_code): """ Get error list. diff --git a/mindinsight/debugger/stream_operator/watchpoint_operator.py b/mindinsight/debugger/stream_operator/watchpoint_operator.py index 0aa8c7e4..ad4c1dcf 100644 --- a/mindinsight/debugger/stream_operator/watchpoint_operator.py +++ b/mindinsight/debugger/stream_operator/watchpoint_operator.py @@ -17,7 +17,7 @@ from queue import Queue from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \ DebuggerCreateWatchPointError, DebuggerUpdateWatchPointError, \ - DebuggerDeleteWatchPointError + DebuggerDeleteWatchPointError, DebuggerConditionUnavailableError from mindinsight.debugger.common.log import LOGGER as log from mindinsight.debugger.common.utils import ServerStatus, \ Streams, is_cst_type @@ -81,7 +81,7 @@ class WatchpointOperator: condition_context = ConditionContext(metadata_stream.backend, metadata_stream.step) if not condition.is_available(condition_context): log.error("Failed to create watchpoint as the condition is not available.") - raise DebuggerCreateWatchPointError( + raise DebuggerConditionUnavailableError( "Failed to create watchpoint as the condition is not available.") if condition.supported_target_type in [TargetTypeEnum.ACTIVATION, TargetTypeEnum.GRADIENT, diff --git a/mindinsight/scripts/parse_summary.py b/mindinsight/scripts/parse_summary.py index e885c735..0380b8d1 100644 --- a/mindinsight/scripts/parse_summary.py +++ b/mindinsight/scripts/parse_summary.py @@ -113,7 +113,7 @@ class Command(BaseCommand): def run(self, args): """ - Execute for start command. + Execute for parse_summary command. Args: args (Namespace): Parsed arguments to hold customized parameters.