add "1", "0" for --enable-debugger add default node for all kind of watch conditionstags/v1.1.0
| @@ -20,11 +20,11 @@ from mindinsight.conf import settings | |||||
| from mindinsight.utils.hook import BaseHook | from mindinsight.utils.hook import BaseHook | ||||
| def str2bool(string): | |||||
| def enable_debugger_string(string): | |||||
| """Convert str to bool""" | """Convert str to bool""" | ||||
| if string.lower() == 'false': | |||||
| if string.lower() in ('false', '0'): | |||||
| return False | return False | ||||
| if string.lower() == 'true': | |||||
| if string.lower() in ('true', '1'): | |||||
| return True | return True | ||||
| raise ValueError | raise ValueError | ||||
| @@ -83,11 +83,11 @@ class Hook(BaseHook): | |||||
| """ | """ | ||||
| parser.add_argument( | parser.add_argument( | ||||
| '--enable-debugger', | '--enable-debugger', | ||||
| type=str2bool, | |||||
| type=enable_debugger_string, | |||||
| action=EnableDebuggerAction, | action=EnableDebuggerAction, | ||||
| default=False, | default=False, | ||||
| help=""" | help=""" | ||||
| Enable debugger or not. | |||||
| Enable debugger or not. The value can be True/False/1/0 (case insensitive). | |||||
| Default is False.""") | Default is False.""") | ||||
| parser.add_argument( | parser.add_argument( | ||||
| @@ -504,9 +504,9 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||||
| } | } | ||||
| hit_params = {} | hit_params = {} | ||||
| for param in watchpoint_hit_proto.watch_condition.params: | for param in watchpoint_hit_proto.watch_condition.params: | ||||
| if param.actual_value is not None and param.name not in \ | |||||
| (ParamNameEnum.RTOL.value, ParamNameEnum.RANGE_START_INCLUSIVE.value, | |||||
| ParamNameEnum.RANGE_END_INCLUSIVE.value): | |||||
| if param.name not in (ParamNameEnum.RTOL.value, ParamNameEnum.RANGE_START_INCLUSIVE.value, | |||||
| ParamNameEnum.RANGE_END_INCLUSIVE.value) \ | |||||
| and watchpoint_hit_proto.error_code == 0: | |||||
| hit_params[param.name] = param.actual_value | hit_params[param.name] = param.actual_value | ||||
| for i, param in enumerate(watchpoint_hit['watchpoint'].condition['params']): | for i, param in enumerate(watchpoint_hit['watchpoint'].condition['params']): | ||||
| name = param['name'] | name = param['name'] | ||||
| @@ -514,8 +514,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||||
| watchpoint_hit['watchpoint'].condition['params'][i]['actual_value'] = hit_params[name] | watchpoint_hit['watchpoint'].condition['params'][i]['actual_value'] = hit_params[name] | ||||
| else: | else: | ||||
| watchpoint_hit['watchpoint'].condition['params'][i]['actual_value'] = None | watchpoint_hit['watchpoint'].condition['params'][i]['actual_value'] = None | ||||
| if watchpoint_hit_proto.error_code is not None: | |||||
| watchpoint_hit['error_code'] = watchpoint_hit_proto.error_code | |||||
| watchpoint_hit['error_code'] = watchpoint_hit_proto.error_code | |||||
| watchpoint_hits.append(watchpoint_hit) | watchpoint_hits.append(watchpoint_hit) | ||||
| self._received_hit = watchpoint_hits | self._received_hit = watchpoint_hits | ||||
| reply = get_ack_reply() | reply = get_ack_reply() | ||||
| @@ -21,7 +21,7 @@ from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValue | |||||
| from mindinsight.debugger.common.log import LOGGER as log | from mindinsight.debugger.common.log import LOGGER as log | ||||
| from mindinsight.debugger.common.utils import ServerStatus, \ | from mindinsight.debugger.common.utils import ServerStatus, \ | ||||
| Streams, is_cst_type | Streams, is_cst_type | ||||
| from mindinsight.debugger.conditionmgr.condition import ConditionIdEnum, TargetTypeEnum, ConditionContext | |||||
| from mindinsight.debugger.conditionmgr.condition import ConditionIdEnum, ConditionContext | |||||
| from mindinsight.debugger.conditionmgr.recommender import get_basic_node_info | from mindinsight.debugger.conditionmgr.recommender import get_basic_node_info | ||||
| from mindinsight.debugger.stream_handler.watchpoint_handler import validate_watch_condition | from mindinsight.debugger.stream_handler.watchpoint_handler import validate_watch_condition | ||||
| @@ -84,10 +84,7 @@ class WatchpointOperator: | |||||
| raise DebuggerConditionUnavailableError( | raise DebuggerConditionUnavailableError( | ||||
| "Failed to create watchpoint as the condition is not available.") | "Failed to create watchpoint as the condition is not available.") | ||||
| if condition.supported_target_type in [TargetTypeEnum.ACTIVATION, TargetTypeEnum.GRADIENT, | |||||
| TargetTypeEnum.WEIGHT]: | |||||
| watch_nodes = get_basic_node_info(condition.supported_target_type.value, self._graph_stream).copy() | |||||
| watch_nodes = get_basic_node_info(condition.supported_target_type.value, self._graph_stream).copy() | |||||
| watchpoint_stream = self._watchpoint_stream | watchpoint_stream = self._watchpoint_stream | ||||
| watch_point_id = watchpoint_stream.create_watchpoint( | watch_point_id = watchpoint_stream.create_watchpoint( | ||||
| self._condition_mgr, watch_condition, watch_nodes, params.get('watch_point_id')) | self._condition_mgr, watch_condition, watch_nodes, params.get('watch_point_id')) | ||||
| @@ -1,31 +1 @@ | |||||
| { | |||||
| "nodes": [ | |||||
| { | |||||
| "name": "Default", | |||||
| "type": "name_scope", | |||||
| "nodes": [ | |||||
| { | |||||
| "name": "Default/optimizer-Momentum", | |||||
| "type": "name_scope", | |||||
| "nodes": [ | |||||
| { | |||||
| "name": "Default/optimizer-Momentum/Parameter[18]_7", | |||||
| "type": "aggregation_scope", | |||||
| "nodes": [ | |||||
| { | |||||
| "name": "Default/optimizer-Momentum/Parameter[18]_7/moments.fc3.bias", | |||||
| "type": "Parameter", | |||||
| "nodes": [], | |||||
| "watched": 0 | |||||
| } | |||||
| ], | |||||
| "watched": 0 | |||||
| } | |||||
| ], | |||||
| "watched": 0 | |||||
| } | |||||
| ], | |||||
| "watched": 0 | |||||
| } | |||||
| ] | |||||
| } | |||||
| {"nodes": [{"name": "Default", "type": "name_scope", "nodes": [{"name": "Default/optimizer-Momentum", "type": "name_scope", "nodes": [{"name": "Default/optimizer-Momentum/Parameter[18]_7", "type": "aggregation_scope", "nodes": [{"name": "Default/optimizer-Momentum/Parameter[18]_7/moments.fc3.bias", "type": "Parameter", "nodes": [], "watched": 2}], "watched": 2}], "watched": 2}], "watched": 2}]} | |||||
| @@ -169,7 +169,7 @@ class TestAscendDebugger: | |||||
| url = 'update-watchpoint' | url = 'update-watchpoint' | ||||
| body_data = {'watch_point_id': watch_point_id, | body_data = {'watch_point_id': watch_point_id, | ||||
| 'watch_nodes': [leaf_node_name], | 'watch_nodes': [leaf_node_name], | ||||
| 'mode': 0} | |||||
| 'mode': 1} | |||||
| get_request_result(app_client, url, body_data) | get_request_result(app_client, url, body_data) | ||||
| # get updated nodes | # get updated nodes | ||||
| url = 'search' | url = 'search' | ||||
| @@ -328,7 +328,7 @@ class TestAscendDebugger: | |||||
| 'watch_nodes': ['Default']}, True), | 'watch_nodes': ['Default']}, True), | ||||
| ('update-watchpoint', | ('update-watchpoint', | ||||
| {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum/Parameter[18]_7'], | {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum/Parameter[18]_7'], | ||||
| 'mode': 0}, True), | |||||
| 'mode': 1}, True), | |||||
| ('update-watchpoint', | ('update-watchpoint', | ||||
| {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum'], | {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum'], | ||||
| 'mode': 1}, True), | 'mode': 1}, True), | ||||
| @@ -438,7 +438,7 @@ class TestGPUDebugger: | |||||
| 'watch_nodes': ['Default/TransData-op99']}, True), | 'watch_nodes': ['Default/TransData-op99']}, True), | ||||
| ('update-watchpoint', | ('update-watchpoint', | ||||
| {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum/Parameter[18]_7'], | {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum/Parameter[18]_7'], | ||||
| 'mode': 0}, True), | |||||
| 'mode': 1}, True), | |||||
| ('update-watchpoint', | ('update-watchpoint', | ||||
| {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum'], | {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum'], | ||||
| 'mode': 1}, True), | 'mode': 1}, True), | ||||
| @@ -450,9 +450,9 @@ class TestGPUDebugger: | |||||
| ], True), | ], True), | ||||
| ('update-watchpoint', | ('update-watchpoint', | ||||
| [{'watch_point_id': 1, 'watch_nodes': ['Default/TransData-op99'], | [{'watch_point_id': 1, 'watch_nodes': ['Default/TransData-op99'], | ||||
| 'mode': 0}, | |||||
| 'mode': 1}, | |||||
| {'watch_point_id': 1, 'watch_nodes': ['Default/TransData-op99'], | {'watch_point_id': 1, 'watch_nodes': ['Default/TransData-op99'], | ||||
| 'mode': 1} | |||||
| 'mode': 0} | |||||
| ], True), | ], True), | ||||
| ('delete-watchpoint', {'watch_point_id': 1}, True) | ('delete-watchpoint', {'watch_point_id': 1}, True) | ||||
| ]) | ]) | ||||
| @@ -32,6 +32,7 @@ from mindinsight.debugger.common.utils import Streams | |||||
| from mindinsight.debugger.debugger_cache import DebuggerCache | from mindinsight.debugger.debugger_cache import DebuggerCache | ||||
| from mindinsight.debugger.debugger_server import DebuggerServer | from mindinsight.debugger.debugger_server import DebuggerServer | ||||
| from mindinsight.debugger.debugger_server import grpc_server_base | from mindinsight.debugger.debugger_server import grpc_server_base | ||||
| from mindinsight.debugger.stream_operator import watchpoint_operator | |||||
| from mindinsight.debugger.stream_handler import GraphHandler, WatchpointHandler, MetadataHandler, \ | from mindinsight.debugger.stream_handler import GraphHandler, WatchpointHandler, MetadataHandler, \ | ||||
| TensorHandler | TensorHandler | ||||
| from tests.ut.debugger.configurations import compare_debugger_result_with_file, mock_tensor_history | from tests.ut.debugger.configurations import compare_debugger_result_with_file, mock_tensor_history | ||||
| @@ -196,6 +197,7 @@ class TestDebuggerServer: | |||||
| @mock.patch.object(MetadataHandler, 'backend', 'GPU') | @mock.patch.object(MetadataHandler, 'backend', 'GPU') | ||||
| @mock.patch.object(GraphHandler, 'get_node_basic_info', return_value=MagicMock()) | @mock.patch.object(GraphHandler, 'get_node_basic_info', return_value=MagicMock()) | ||||
| @mock.patch.object(GraphHandler, 'get_node_type', return_value='aggregation_scope') | @mock.patch.object(GraphHandler, 'get_node_type', return_value='aggregation_scope') | ||||
| @mock.patch.object(watchpoint_operator, 'get_basic_node_info', return_value=MagicMock()) | |||||
| @mock.patch.object(WatchpointHandler, 'create_watchpoint') | @mock.patch.object(WatchpointHandler, 'create_watchpoint') | ||||
| def test_create_watchpoint(self, *args): | def test_create_watchpoint(self, *args): | ||||
| """Test create watchpoint.""" | """Test create watchpoint.""" | ||||