add "1", "0" for --enable-debugger add default node for all kind of watch conditionstags/v1.1.0
| @@ -20,11 +20,11 @@ from mindinsight.conf import settings | |||
| from mindinsight.utils.hook import BaseHook | |||
| def str2bool(string): | |||
| def enable_debugger_string(string): | |||
| """Convert str to bool""" | |||
| if string.lower() == 'false': | |||
| if string.lower() in ('false', '0'): | |||
| return False | |||
| if string.lower() == 'true': | |||
| if string.lower() in ('true', '1'): | |||
| return True | |||
| raise ValueError | |||
| @@ -83,11 +83,11 @@ class Hook(BaseHook): | |||
| """ | |||
| parser.add_argument( | |||
| '--enable-debugger', | |||
| type=str2bool, | |||
| type=enable_debugger_string, | |||
| action=EnableDebuggerAction, | |||
| default=False, | |||
| help=""" | |||
| Enable debugger or not. | |||
| Enable debugger or not. The value can be True/False/1/0 (case insensitive). | |||
| Default is False.""") | |||
| parser.add_argument( | |||
| @@ -504,9 +504,9 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| } | |||
| hit_params = {} | |||
| for param in watchpoint_hit_proto.watch_condition.params: | |||
| if param.actual_value is not None and param.name not in \ | |||
| (ParamNameEnum.RTOL.value, ParamNameEnum.RANGE_START_INCLUSIVE.value, | |||
| ParamNameEnum.RANGE_END_INCLUSIVE.value): | |||
| if param.name not in (ParamNameEnum.RTOL.value, ParamNameEnum.RANGE_START_INCLUSIVE.value, | |||
| ParamNameEnum.RANGE_END_INCLUSIVE.value) \ | |||
| and watchpoint_hit_proto.error_code == 0: | |||
| hit_params[param.name] = param.actual_value | |||
| for i, param in enumerate(watchpoint_hit['watchpoint'].condition['params']): | |||
| name = param['name'] | |||
| @@ -514,8 +514,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| watchpoint_hit['watchpoint'].condition['params'][i]['actual_value'] = hit_params[name] | |||
| else: | |||
| watchpoint_hit['watchpoint'].condition['params'][i]['actual_value'] = None | |||
| if watchpoint_hit_proto.error_code is not None: | |||
| watchpoint_hit['error_code'] = watchpoint_hit_proto.error_code | |||
| watchpoint_hit['error_code'] = watchpoint_hit_proto.error_code | |||
| watchpoint_hits.append(watchpoint_hit) | |||
| self._received_hit = watchpoint_hits | |||
| reply = get_ack_reply() | |||
| @@ -21,7 +21,7 @@ from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValue | |||
| from mindinsight.debugger.common.log import LOGGER as log | |||
| from mindinsight.debugger.common.utils import ServerStatus, \ | |||
| Streams, is_cst_type | |||
| from mindinsight.debugger.conditionmgr.condition import ConditionIdEnum, TargetTypeEnum, ConditionContext | |||
| from mindinsight.debugger.conditionmgr.condition import ConditionIdEnum, ConditionContext | |||
| from mindinsight.debugger.conditionmgr.recommender import get_basic_node_info | |||
| from mindinsight.debugger.stream_handler.watchpoint_handler import validate_watch_condition | |||
| @@ -84,10 +84,7 @@ class WatchpointOperator: | |||
| raise DebuggerConditionUnavailableError( | |||
| "Failed to create watchpoint as the condition is not available.") | |||
| if condition.supported_target_type in [TargetTypeEnum.ACTIVATION, TargetTypeEnum.GRADIENT, | |||
| TargetTypeEnum.WEIGHT]: | |||
| watch_nodes = get_basic_node_info(condition.supported_target_type.value, self._graph_stream).copy() | |||
| watch_nodes = get_basic_node_info(condition.supported_target_type.value, self._graph_stream).copy() | |||
| watchpoint_stream = self._watchpoint_stream | |||
| watch_point_id = watchpoint_stream.create_watchpoint( | |||
| self._condition_mgr, watch_condition, watch_nodes, params.get('watch_point_id')) | |||
| @@ -1,31 +1 @@ | |||
| { | |||
| "nodes": [ | |||
| { | |||
| "name": "Default", | |||
| "type": "name_scope", | |||
| "nodes": [ | |||
| { | |||
| "name": "Default/optimizer-Momentum", | |||
| "type": "name_scope", | |||
| "nodes": [ | |||
| { | |||
| "name": "Default/optimizer-Momentum/Parameter[18]_7", | |||
| "type": "aggregation_scope", | |||
| "nodes": [ | |||
| { | |||
| "name": "Default/optimizer-Momentum/Parameter[18]_7/moments.fc3.bias", | |||
| "type": "Parameter", | |||
| "nodes": [], | |||
| "watched": 0 | |||
| } | |||
| ], | |||
| "watched": 0 | |||
| } | |||
| ], | |||
| "watched": 0 | |||
| } | |||
| ], | |||
| "watched": 0 | |||
| } | |||
| ] | |||
| } | |||
| {"nodes": [{"name": "Default", "type": "name_scope", "nodes": [{"name": "Default/optimizer-Momentum", "type": "name_scope", "nodes": [{"name": "Default/optimizer-Momentum/Parameter[18]_7", "type": "aggregation_scope", "nodes": [{"name": "Default/optimizer-Momentum/Parameter[18]_7/moments.fc3.bias", "type": "Parameter", "nodes": [], "watched": 2}], "watched": 2}], "watched": 2}], "watched": 2}]} | |||
| @@ -169,7 +169,7 @@ class TestAscendDebugger: | |||
| url = 'update-watchpoint' | |||
| body_data = {'watch_point_id': watch_point_id, | |||
| 'watch_nodes': [leaf_node_name], | |||
| 'mode': 0} | |||
| 'mode': 1} | |||
| get_request_result(app_client, url, body_data) | |||
| # get updated nodes | |||
| url = 'search' | |||
| @@ -328,7 +328,7 @@ class TestAscendDebugger: | |||
| 'watch_nodes': ['Default']}, True), | |||
| ('update-watchpoint', | |||
| {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum/Parameter[18]_7'], | |||
| 'mode': 0}, True), | |||
| 'mode': 1}, True), | |||
| ('update-watchpoint', | |||
| {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum'], | |||
| 'mode': 1}, True), | |||
| @@ -438,7 +438,7 @@ class TestGPUDebugger: | |||
| 'watch_nodes': ['Default/TransData-op99']}, True), | |||
| ('update-watchpoint', | |||
| {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum/Parameter[18]_7'], | |||
| 'mode': 0}, True), | |||
| 'mode': 1}, True), | |||
| ('update-watchpoint', | |||
| {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum'], | |||
| 'mode': 1}, True), | |||
| @@ -450,9 +450,9 @@ class TestGPUDebugger: | |||
| ], True), | |||
| ('update-watchpoint', | |||
| [{'watch_point_id': 1, 'watch_nodes': ['Default/TransData-op99'], | |||
| 'mode': 0}, | |||
| 'mode': 1}, | |||
| {'watch_point_id': 1, 'watch_nodes': ['Default/TransData-op99'], | |||
| 'mode': 1} | |||
| 'mode': 0} | |||
| ], True), | |||
| ('delete-watchpoint', {'watch_point_id': 1}, True) | |||
| ]) | |||
| @@ -32,6 +32,7 @@ from mindinsight.debugger.common.utils import Streams | |||
| from mindinsight.debugger.debugger_cache import DebuggerCache | |||
| from mindinsight.debugger.debugger_server import DebuggerServer | |||
| from mindinsight.debugger.debugger_server import grpc_server_base | |||
| from mindinsight.debugger.stream_operator import watchpoint_operator | |||
| from mindinsight.debugger.stream_handler import GraphHandler, WatchpointHandler, MetadataHandler, \ | |||
| TensorHandler | |||
| from tests.ut.debugger.configurations import compare_debugger_result_with_file, mock_tensor_history | |||
| @@ -196,6 +197,7 @@ class TestDebuggerServer: | |||
| @mock.patch.object(MetadataHandler, 'backend', 'GPU') | |||
| @mock.patch.object(GraphHandler, 'get_node_basic_info', return_value=MagicMock()) | |||
| @mock.patch.object(GraphHandler, 'get_node_type', return_value='aggregation_scope') | |||
| @mock.patch.object(watchpoint_operator, 'get_basic_node_info', return_value=MagicMock()) | |||
| @mock.patch.object(WatchpointHandler, 'create_watchpoint') | |||
| def test_create_watchpoint(self, *args): | |||
| """Test create watchpoint.""" | |||