From: @jiang-shuqiang Reviewed-by: @wenkai_dist Signed-off-by:tags/v1.1.0
| @@ -28,6 +28,7 @@ class DebuggerErrors(DebuggerErrorCodes): | |||
| PARAM_TYPE_ERROR = 0 | _PARAM_ERROR_MASK | |||
| PARAM_VALUE_ERROR = 1 | _PARAM_ERROR_MASK | |||
| STEP_NUM_ERROR = 2 | _PARAM_ERROR_MASK | |||
| DEBUGGER_CONDITION_UNAVAILABLE_ERROR = 3 | _PARAM_ERROR_MASK | |||
| NODE_NOT_IN_GRAPH_ERROR = 0 | _DEBUGGER_GRAPH_ERROR | |||
| GRAPH_NOT_EXIST_ERROR = 1 | _DEBUGGER_GRAPH_ERROR | |||
| @@ -41,6 +42,7 @@ class DebuggerErrors(DebuggerErrorCodes): | |||
| RECHECK_ERROR = 6 | _DEBUGGER_RUNNING_ERROR | |||
| TENSOR_GRAPH_ERROR = 7 | _DEBUGGER_RUNNING_ERROR | |||
| TENSOR_HIT_ERROR = 8 | _DEBUGGER_RUNNING_ERROR | |||
| SET_RECOMMEND_WATCHPOINT_ERROR = 9 | _DEBUGGER_RUNNING_ERROR | |||
| @unique | |||
| @@ -48,6 +50,7 @@ class DebuggerErrorMsg(Enum): | |||
| """Debugger error messages.""" | |||
| PARAM_TYPE_ERROR = "TypeError. {}" | |||
| PARAM_VALUE_ERROR = "ValueError. {}" | |||
| DEBUGGER_CONDITION_UNAVAILABLE_ERROR = "Condition is unavailable. {}" | |||
| GRAPH_NOT_EXIST_ERROR = "The graph does not exist." | |||
| @@ -59,3 +62,4 @@ class DebuggerErrorMsg(Enum): | |||
| RECHECK_ERROR = "Recheck failed. {}" | |||
| TENSOR_GRAPH_ERROR = "Get tensor graphs failed." | |||
| TENSOR_HIT_ERROR = "Get tensor hits failed." | |||
| SET_RECOMMEND_WATCHPOINT_ERROR = "Set Recommend Watchpoints failed." | |||
| @@ -168,3 +168,25 @@ class DebuggerTensorHitError(MindInsightException): | |||
| message=DebuggerErrorMsg.TENSOR_HIT_ERROR.value, | |||
| http_code=400 | |||
| ) | |||
| class DebuggerSetRecommendWatchpointsError(MindInsightException): | |||
| """The set recommend watchpoints error in debugger module.""" | |||
| def __init__(self): | |||
| super(DebuggerSetRecommendWatchpointsError, self).__init__( | |||
| error=DebuggerErrors.SET_RECOMMEND_WATCHPOINT_ERROR, | |||
| message=DebuggerErrorMsg.SET_RECOMMEND_WATCHPOINT_ERROR.value, | |||
| http_code=400 | |||
| ) | |||
| class DebuggerConditionUnavailableError(MindInsightException): | |||
| """The condition unavailable error in debugger module.""" | |||
| def __init__(self, msg): | |||
| super(DebuggerConditionUnavailableError, self).__init__( | |||
| error=DebuggerErrors.DEBUGGER_CONDITION_UNAVAILABLE_ERROR, | |||
| message=DebuggerErrorMsg.DEBUGGER_CONDITION_UNAVAILABLE_ERROR.value.format(msg), | |||
| http_code=400 | |||
| ) | |||
| @@ -503,7 +503,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| } | |||
| hit_params = {} | |||
| for param in watchpoint_hit_proto.watch_condition.params: | |||
| if param.actual_value: | |||
| if param.actual_value is not None: | |||
| hit_params[param.name] = param.actual_value | |||
| for i, param in enumerate(watchpoint_hit['watchpoint'].condition['params']): | |||
| name = param['name'] | |||
| @@ -511,7 +511,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| watchpoint_hit['watchpoint'].condition['params'][i]['actual_value'] = hit_params[name] | |||
| else: | |||
| watchpoint_hit['watchpoint'].condition['params'][i]['actual_value'] = None | |||
| if watchpoint_hit_proto.error_code: | |||
| if watchpoint_hit_proto.error_code is not None: | |||
| watchpoint_hit['error_code'] = watchpoint_hit_proto.error_code | |||
| watchpoint_hits.append(watchpoint_hit) | |||
| self._received_hit = watchpoint_hits | |||
| @@ -28,7 +28,7 @@ from mindinsight.datavisual.data_transform.graph import NodeTypeEnum | |||
| from mindinsight.datavisual.utils.tools import to_float | |||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \ | |||
| DebuggerParamTypeError, DebuggerCompareTensorError, DebuggerTensorGraphError, \ | |||
| DebuggerTensorHitError, MindInsightException | |||
| DebuggerTensorHitError, DebuggerSetRecommendWatchpointsError, MindInsightException | |||
| from mindinsight.debugger.common.log import LOGGER as log | |||
| from mindinsight.debugger.common.utils import ServerStatus, \ | |||
| create_view_event_from_tensor_basic_info, Streams | |||
| @@ -81,10 +81,13 @@ class DebuggerServer: | |||
| log.error("Bool param should be given for set_recommended") | |||
| raise DebuggerParamValueError("Bool param should be given.") | |||
| metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA) | |||
| if metadata_stream.recommendation_confirmed: | |||
| log.error("User has confirmed setting recommended watchpoints") | |||
| raise DebuggerSetRecommendWatchpointsError() | |||
| condition_context = ConditionContext(metadata_stream.backend, metadata_stream.step) | |||
| log.debug("Train_id: %s, backend: %s", train_id, condition_context.backend) | |||
| res = metadata_stream.get(['state', 'enable_recheck']) | |||
| if set_recommended and not metadata_stream.recommendation_confirmed: | |||
| if set_recommended: | |||
| res['id'] = self._add_recommended_watchpoints(condition_context) | |||
| metadata_stream.recommendation_confirmed = True | |||
| return res | |||
| @@ -25,6 +25,10 @@ from mindinsight.debugger.stream_cache.watchpoint import Watchpoint, WatchpointH | |||
| from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase | |||
| RANGE_START = 'range_start_inclusive' | |||
| RANGE_END = 'range_end_inclusive' | |||
| class WatchpointHandler(StreamHandlerBase): | |||
| """Watchpoint Handler.""" | |||
| @@ -446,7 +450,7 @@ class WatchpointHitHandler(StreamHandlerBase): | |||
| watch_points.append(watchpoint) | |||
| if watch_points: | |||
| watch_points.sort(key=_watchpoint_id) | |||
| watch_points.sort(key=lambda watch_point: watch_point.get('id')) | |||
| res = { | |||
| 'slot': slot, | |||
| 'watch_points': watch_points | |||
| @@ -540,6 +544,7 @@ def validate_watch_condition_params(condition_mgr, watch_condition): | |||
| check_param_num = 0 | |||
| support_params = set() | |||
| defined_support_params = set() | |||
| range_param = {RANGE_START: None, RANGE_END: None} | |||
| for param in params: | |||
| if len(param) > 2: | |||
| log.error("Invalid param keys for condition: %s", condition_id) | |||
| @@ -573,6 +578,9 @@ def validate_watch_condition_params(condition_mgr, watch_condition): | |||
| else: | |||
| support_params.add(condition_param.name) | |||
| if condition_param_name in range_param: | |||
| range_param[condition_param_name] = param.get("value") | |||
| if check_param_num > 1: | |||
| log.error("Multiple check params for condition: %s", condition_id) | |||
| raise DebuggerParamValueError("Multiple check params.") | |||
| @@ -581,6 +589,12 @@ def validate_watch_condition_params(condition_mgr, watch_condition): | |||
| log.error("Invalid support params for condition: %s", condition_id) | |||
| raise DebuggerParamValueError("Invalid support params.") | |||
| if range_param.get(RANGE_START) is not None and \ | |||
| range_param.get(RANGE_END) is not None and range_param.get(RANGE_START) > \ | |||
| range_param.get(RANGE_END): | |||
| log.error("Invalid support params for condition: %s", condition_id) | |||
| raise DebuggerParamValueError("Invalid support params.") | |||
| def set_default_param(condition_mgr, watch_condition): | |||
| """ | |||
| @@ -615,10 +629,6 @@ def set_default_param(condition_mgr, watch_condition): | |||
| return watch_condition | |||
| def _watchpoint_id(watchpoint): | |||
| return watchpoint.get('id') | |||
| def _get_error_list(error_code): | |||
| """ | |||
| Get error list. | |||
| @@ -17,7 +17,7 @@ from queue import Queue | |||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \ | |||
| DebuggerCreateWatchPointError, DebuggerUpdateWatchPointError, \ | |||
| DebuggerDeleteWatchPointError | |||
| DebuggerDeleteWatchPointError, DebuggerConditionUnavailableError | |||
| from mindinsight.debugger.common.log import LOGGER as log | |||
| from mindinsight.debugger.common.utils import ServerStatus, \ | |||
| Streams, is_cst_type | |||
| @@ -81,7 +81,7 @@ class WatchpointOperator: | |||
| condition_context = ConditionContext(metadata_stream.backend, metadata_stream.step) | |||
| if not condition.is_available(condition_context): | |||
| log.error("Failed to create watchpoint as the condition is not available.") | |||
| raise DebuggerCreateWatchPointError( | |||
| raise DebuggerConditionUnavailableError( | |||
| "Failed to create watchpoint as the condition is not available.") | |||
| if condition.supported_target_type in [TargetTypeEnum.ACTIVATION, TargetTypeEnum.GRADIENT, | |||
| @@ -113,7 +113,7 @@ class Command(BaseCommand): | |||
| def run(self, args): | |||
| """ | |||
| Execute for start command. | |||
| Execute for parse_summary command. | |||
| Args: | |||
| args (Namespace): Parsed arguments to hold customized parameters. | |||