Browse Source

fix argument and add default node

add "1", "0" for --enable-debugger
add default node for all kind of watch conditions
tags/v1.1.0
jiangshuqiang 4 years ago
parent
commit
632683bd01
7 changed files with 20 additions and 1788 deletions
  1. +5
    -5
      mindinsight/common/hook/debugger.py
  2. +4
    -5
      mindinsight/debugger/debugger_grpc_server.py
  3. +2
    -5
      mindinsight/debugger/stream_operator/watchpoint_operator.py
  4. +1
    -1737
      tests/st/func/debugger/expect_results/restful_results/retrieve_single_watchpoint_hit.json
  5. +1
    -31
      tests/st/func/debugger/expect_results/restful_results/search_unwatched_leaf_node.json
  6. +5
    -5
      tests/st/func/debugger/test_restful_api.py
  7. +2
    -0
      tests/ut/debugger/test_debugger_server.py

+ 5
- 5
mindinsight/common/hook/debugger.py View File

@@ -20,11 +20,11 @@ from mindinsight.conf import settings
from mindinsight.utils.hook import BaseHook
def str2bool(string):
def enable_debugger_string(string):
"""Convert str to bool"""
if string.lower() == 'false':
if string.lower() in ('false', '0'):
return False
if string.lower() == 'true':
if string.lower() in ('true', '1'):
return True
raise ValueError
@@ -83,11 +83,11 @@ class Hook(BaseHook):
"""
parser.add_argument(
'--enable-debugger',
type=str2bool,
type=enable_debugger_string,
action=EnableDebuggerAction,
default=False,
help="""
Enable debugger or not.
Enable debugger or not. The value can be True/False/1/0 (case insensitive).
Default is False.""")
parser.add_argument(


+ 4
- 5
mindinsight/debugger/debugger_grpc_server.py View File

@@ -504,9 +504,9 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
}
hit_params = {}
for param in watchpoint_hit_proto.watch_condition.params:
if param.actual_value is not None and param.name not in \
(ParamNameEnum.RTOL.value, ParamNameEnum.RANGE_START_INCLUSIVE.value,
ParamNameEnum.RANGE_END_INCLUSIVE.value):
if param.name not in (ParamNameEnum.RTOL.value, ParamNameEnum.RANGE_START_INCLUSIVE.value,
ParamNameEnum.RANGE_END_INCLUSIVE.value) \
and watchpoint_hit_proto.error_code == 0:
hit_params[param.name] = param.actual_value
for i, param in enumerate(watchpoint_hit['watchpoint'].condition['params']):
name = param['name']
@@ -514,8 +514,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
watchpoint_hit['watchpoint'].condition['params'][i]['actual_value'] = hit_params[name]
else:
watchpoint_hit['watchpoint'].condition['params'][i]['actual_value'] = None
if watchpoint_hit_proto.error_code is not None:
watchpoint_hit['error_code'] = watchpoint_hit_proto.error_code
watchpoint_hit['error_code'] = watchpoint_hit_proto.error_code
watchpoint_hits.append(watchpoint_hit)
self._received_hit = watchpoint_hits
reply = get_ack_reply()


+ 2
- 5
mindinsight/debugger/stream_operator/watchpoint_operator.py View File

@@ -21,7 +21,7 @@ from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValue
from mindinsight.debugger.common.log import LOGGER as log
from mindinsight.debugger.common.utils import ServerStatus, \
Streams, is_cst_type
from mindinsight.debugger.conditionmgr.condition import ConditionIdEnum, TargetTypeEnum, ConditionContext
from mindinsight.debugger.conditionmgr.condition import ConditionIdEnum, ConditionContext
from mindinsight.debugger.conditionmgr.recommender import get_basic_node_info
from mindinsight.debugger.stream_handler.watchpoint_handler import validate_watch_condition

@@ -84,10 +84,7 @@ class WatchpointOperator:
raise DebuggerConditionUnavailableError(
"Failed to create watchpoint as the condition is not available.")

if condition.supported_target_type in [TargetTypeEnum.ACTIVATION, TargetTypeEnum.GRADIENT,
TargetTypeEnum.WEIGHT]:
watch_nodes = get_basic_node_info(condition.supported_target_type.value, self._graph_stream).copy()

watch_nodes = get_basic_node_info(condition.supported_target_type.value, self._graph_stream).copy()
watchpoint_stream = self._watchpoint_stream
watch_point_id = watchpoint_stream.create_watchpoint(
self._condition_mgr, watch_condition, watch_nodes, params.get('watch_point_id'))


+ 1
- 1737
tests/st/func/debugger/expect_results/restful_results/retrieve_single_watchpoint_hit.json
File diff suppressed because it is too large
View File


+ 1
- 31
tests/st/func/debugger/expect_results/restful_results/search_unwatched_leaf_node.json View File

@@ -1,31 +1 @@
{
"nodes": [
{
"name": "Default",
"type": "name_scope",
"nodes": [
{
"name": "Default/optimizer-Momentum",
"type": "name_scope",
"nodes": [
{
"name": "Default/optimizer-Momentum/Parameter[18]_7",
"type": "aggregation_scope",
"nodes": [
{
"name": "Default/optimizer-Momentum/Parameter[18]_7/moments.fc3.bias",
"type": "Parameter",
"nodes": [],
"watched": 0
}
],
"watched": 0
}
],
"watched": 0
}
],
"watched": 0
}
]
}
{"nodes": [{"name": "Default", "type": "name_scope", "nodes": [{"name": "Default/optimizer-Momentum", "type": "name_scope", "nodes": [{"name": "Default/optimizer-Momentum/Parameter[18]_7", "type": "aggregation_scope", "nodes": [{"name": "Default/optimizer-Momentum/Parameter[18]_7/moments.fc3.bias", "type": "Parameter", "nodes": [], "watched": 2}], "watched": 2}], "watched": 2}], "watched": 2}]}

+ 5
- 5
tests/st/func/debugger/test_restful_api.py View File

@@ -169,7 +169,7 @@ class TestAscendDebugger:
url = 'update-watchpoint'
body_data = {'watch_point_id': watch_point_id,
'watch_nodes': [leaf_node_name],
'mode': 0}
'mode': 1}
get_request_result(app_client, url, body_data)
# get updated nodes
url = 'search'
@@ -328,7 +328,7 @@ class TestAscendDebugger:
'watch_nodes': ['Default']}, True),
('update-watchpoint',
{'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum/Parameter[18]_7'],
'mode': 0}, True),
'mode': 1}, True),
('update-watchpoint',
{'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum'],
'mode': 1}, True),
@@ -438,7 +438,7 @@ class TestGPUDebugger:
'watch_nodes': ['Default/TransData-op99']}, True),
('update-watchpoint',
{'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum/Parameter[18]_7'],
'mode': 0}, True),
'mode': 1}, True),
('update-watchpoint',
{'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum'],
'mode': 1}, True),
@@ -450,9 +450,9 @@ class TestGPUDebugger:
], True),
('update-watchpoint',
[{'watch_point_id': 1, 'watch_nodes': ['Default/TransData-op99'],
'mode': 0},
'mode': 1},
{'watch_point_id': 1, 'watch_nodes': ['Default/TransData-op99'],
'mode': 1}
'mode': 0}
], True),
('delete-watchpoint', {'watch_point_id': 1}, True)
])


+ 2
- 0
tests/ut/debugger/test_debugger_server.py View File

@@ -32,6 +32,7 @@ from mindinsight.debugger.common.utils import Streams
from mindinsight.debugger.debugger_cache import DebuggerCache
from mindinsight.debugger.debugger_server import DebuggerServer
from mindinsight.debugger.debugger_server import grpc_server_base
from mindinsight.debugger.stream_operator import watchpoint_operator
from mindinsight.debugger.stream_handler import GraphHandler, WatchpointHandler, MetadataHandler, \
TensorHandler
from tests.ut.debugger.configurations import compare_debugger_result_with_file, mock_tensor_history
@@ -196,6 +197,7 @@ class TestDebuggerServer:
@mock.patch.object(MetadataHandler, 'backend', 'GPU')
@mock.patch.object(GraphHandler, 'get_node_basic_info', return_value=MagicMock())
@mock.patch.object(GraphHandler, 'get_node_type', return_value='aggregation_scope')
@mock.patch.object(watchpoint_operator, 'get_basic_node_info', return_value=MagicMock())
@mock.patch.object(WatchpointHandler, 'create_watchpoint')
def test_create_watchpoint(self, *args):
"""Test create watchpoint."""


Loading…
Cancel
Save