Browse Source

!740 debugger: added ut for watchpoint

Merge pull request !740 from zhangyunshu/zys_utst_new
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
f107f45e71
6 changed files with 170 additions and 57 deletions
  1. +47
    -3
      tests/ut/debugger/configurations.py
  2. +33
    -0
      tests/ut/debugger/expected_results/watchpoint/watchpoint_handler_get_0.json
  3. +1
    -0
      tests/ut/debugger/expected_results/watchpoint/watchpoint_handler_get_1.json
  4. +1
    -0
      tests/ut/debugger/expected_results/watchpoint/watchpoint_hit_handler_get_0.json
  5. +1
    -0
      tests/ut/debugger/expected_results/watchpoint/watchpoint_hit_handler_get_1.json
  6. +87
    -54
      tests/ut/debugger/stream_handler/test_watchpoint_handler.py

+ 47
- 3
tests/ut/debugger/configurations.py View File

@@ -18,17 +18,20 @@ import os

from google.protobuf import json_format

from mindinsight.datavisual.data_transform.graph import NodeTypeEnum
from mindinsight.debugger.common.utils import NodeBasicInfo
from mindinsight.debugger.proto import ms_graph_pb2
from mindinsight.debugger.stream_handler.graph_handler import GraphHandler
from mindinsight.debugger.stream_handler.watchpoint_handler import WatchpointHitHandler

graph_proto_file = os.path.join(
GRAPH_PROTO_FILE = os.path.join(
os.path.dirname(__file__), '../../utils/resource/graph_pb/lenet.pb'
)


def get_graph_proto():
"""Get graph proto."""
with open(graph_proto_file, 'rb') as f:
with open(GRAPH_PROTO_FILE, 'rb') as f:
content = f.read()

graph = ms_graph_pb2.GraphProto()
@@ -38,7 +41,7 @@ def get_graph_proto():


def init_graph_handler():
"""Init graph proto."""
"""Init GraphHandler."""
graph = get_graph_proto()
graph_handler = GraphHandler()
graph_handler.put(graph)
@@ -46,6 +49,47 @@ def init_graph_handler():
return graph_handler


def init_watchpoint_hit_handler(value):
"""Init WatchpointHitHandler."""
wph_handler = WatchpointHitHandler()
wph_handler.put(value)

return wph_handler


def get_node_basic_infos(node_names):
"""Get node info according to node names."""
if not node_names:
return []
graph_stream = init_graph_handler()
node_infos = []
for node_name in node_names:
node_type = graph_stream.get_node_type(node_name)
if node_type == NodeTypeEnum.AGGREGATION_SCOPE.value:
sub_nodes = graph_stream.get_nodes_by_scope(node_name)
sub_infos = [NodeBasicInfo(name=node.name, full_name=node.full_name, type=node.type)
for node in sub_nodes]
node_infos.extend(sub_infos)
full_name = graph_stream.get_full_name(node_name)
node_infos.append(NodeBasicInfo(name=node_name, full_name=full_name, type=node_type))
return node_infos


def get_watch_nodes_by_search(watch_nodes):
"""Get watched leaf nodes by search name."""
watched_leaf_nodes = []
graph_stream = init_graph_handler()
for search_name in watch_nodes:
search_nodes = graph_stream.get_searched_node_list()
search_node_names = [
NodeBasicInfo(name=node.name, full_name=node.full_name, type=node.type)
for node in search_nodes
if node.name.startswith(search_name)]
watched_leaf_nodes.extend(search_node_names)

return watched_leaf_nodes


def mock_tensor_proto():
"""Mock tensor proto."""
tensor_dict = {


+ 33
- 0
tests/ut/debugger/expected_results/watchpoint/watchpoint_handler_get_0.json View File

@@ -0,0 +1,33 @@
[
{
"watchCondition": {
"condition": "inf"
},
"id": 1
},
{
"watchCondition": {
"condition": "inf"
},
"id": 2,
"watchNodes": [
{
"nodeName": "Default",
"nodeType": "scope"
}
]
},
{
"watchCondition": {
"condition": "max_gt",
"value": 1.0
},
"id": 3,
"watchNodes": [
{
"nodeName": "Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGradV2-op92",
"nodeType": "leaf"
}
]
}
]

+ 1
- 0
tests/ut/debugger/expected_results/watchpoint/watchpoint_handler_get_1.json View File

@@ -0,0 +1 @@
[{"id": 1, "watch_condition": {"condition": "INF"}}, {"id": 2, "watch_condition": {"condition": "INF"}}, {"id": 3, "watch_condition": {"condition": "MAX_GT", "param": 1}}]

+ 1
- 0
tests/ut/debugger/expected_results/watchpoint/watchpoint_hit_handler_get_0.json View File

@@ -0,0 +1 @@
{"watch_point_hits": [{"node_name": "Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGradV2-op92", "watch_points": [{"id": 1, "watch_condition": {"condition": "MAX_GT", "param": 1}}]}]}

+ 1
- 0
tests/ut/debugger/expected_results/watchpoint/watchpoint_hit_handler_get_1.json View File

@@ -0,0 +1 @@
null

+ 87
- 54
tests/ut/debugger/stream_handler/test_watchpoint_handler.py View File

@@ -17,6 +17,7 @@ import json
import os
from unittest import mock, TestCase

from google.protobuf import json_format
import pytest

from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \
@@ -27,7 +28,9 @@ from mindinsight.debugger.stream_handler.watchpoint_handler import WatchpointHan
WatchpointHitHandler, validate_watch_condition, validate_watch_condition_params

from tests.ut.debugger.configurations import init_graph_handler, mock_tensor_proto, \
mock_tensor_history
mock_tensor_history, get_node_basic_infos, get_watch_nodes_by_search, \
init_watchpoint_hit_handler
from tests.utils.tools import compare_result_with_file


class TestWatchpointHandler:
@@ -36,6 +39,8 @@ class TestWatchpointHandler:
def setup_class(cls):
"""Init WatchpointHandler for watchpoint unittest."""
cls.handler = WatchpointHandler()
cls.results_dir = os.path.join(os.path.dirname(__file__),
'../expected_results/watchpoint')
cls.graph_results_dir = os.path.join(os.path.dirname(__file__),
'../expected_results/graph')
cls.graph_stream = init_graph_handler()
@@ -47,26 +52,64 @@ class TestWatchpointHandler:
({'condition': 'MAX_GT', 'param': 1},
["Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGradV2-op92"], None, 3)
])
@mock.patch.object(Watchpoint, 'add_nodes')
def test_create_watchpoint(self, mock_add_nodes, watch_condition, watch_nodes,
def test_create_watchpoint(self, watch_condition, watch_nodes,
watch_point_id, expect_new_id):
"""Test create_watchpoint."""
mock_add_nodes.return_value = None
watch_nodes = get_node_basic_infos(watch_nodes)
watch_point_id = self.handler.create_watchpoint(watch_condition, watch_nodes, watch_point_id)
assert watch_point_id == expect_new_id

@pytest.mark.parametrize("filter_condition", [True, False])
@mock.patch.object(Watchpoint, 'get_set_cmd')
@mock.patch.object(Watchpoint, 'get_watch_condition_info')
def test_get(self, mock_get_wp, mock_get_cmd, filter_condition):
"""Test get."""
mock_get_wp.return_value = None
mock_get_cmd.return_value = None
@pytest.mark.parametrize(
"watch_point_id, watch_nodes, watched, expect_updated_id", [
(3, ["Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGradV2-op92"], 1, 3),
(3, [], 1, 3)
])
def test_update_watchpoint_add(self, watch_point_id, watch_nodes, watched, expect_updated_id):
"""Test update_watchpoint on addition."""
watch_nodes = get_node_basic_infos(watch_nodes)
with TestCase().assertLogs(logger=log, level='DEBUG') as log_content:
self.handler.update_watchpoint(watch_point_id, watch_nodes, watched)
TestCase().assertIn(f"DEBUG:debugger.debugger:Update watchpoint {expect_updated_id} in cache.",
log_content.output)

@pytest.mark.parametrize(
"watch_point_id, watch_nodes, watched, expect_updated_id", [
(2, ["Default"], 0, 2),
(2, [], 0, 2),
])
def test_update_watchpoint_delete(self, watch_point_id, watch_nodes, watched, expect_updated_id):
"""Test update_watchpoint on deletion."""
watch_nodes = get_watch_nodes_by_search(watch_nodes)
with TestCase().assertLogs(logger=log, level='DEBUG') as log_content:
self.handler.get(filter_condition)
TestCase().assertIn(f"DEBUG:debugger.debugger:get the watch points with filter_condition:{filter_condition}",
self.handler.update_watchpoint(watch_point_id, watch_nodes, watched)
TestCase().assertIn(f"DEBUG:debugger.debugger:Update watchpoint {expect_updated_id} in cache.",
log_content.output)

@pytest.mark.parametrize("filter_condition, result_file", [
(True, 'watchpoint_handler_get_0.json')
])
def test_get_filter_true(self, filter_condition, result_file):
"""Test get with filter_condition is True."""
file_path = os.path.join(self.results_dir, result_file)
with open(file_path, 'r') as f:
contents = json.load(f)

reply = self.handler.get(filter_condition)
protos = reply.get('watch_points')
for proto in protos:
msg_dict = json_format.MessageToDict(proto)
assert msg_dict in contents

@pytest.mark.parametrize("filter_condition, result_file", [
(False, 'watchpoint_handler_get_1.json')
])
def test_get_filter_false(self, filter_condition, result_file):
"""Test get with filer_condition is False."""
file_path = os.path.join(self.results_dir, result_file)
reply = self.handler.get(filter_condition)
watch_points = reply.get('watch_points')
compare_result_with_file(watch_points, file_path)

def test_get_watchpoint_by_id_except(self):
"""Test get_watchpoint_by_id."""
watchpoint_id = 4
@@ -78,32 +121,12 @@ class TestWatchpointHandler:
@pytest.mark.parametrize("graph_file, watch_point_id", [
('graph_handler_get_3_single_node.json', 4)
])
@mock.patch.object(WatchpointHandler, '_set_watch_status_recursively')
def test_set_watch_nodes(self, mock_set_recur, graph_file, watch_point_id):
def test_set_watch_nodes(self, graph_file, watch_point_id):
"""Test set_watch_nodes."""
path = os.path.join(self.graph_results_dir, graph_file)
with open(path, 'r') as f:
graph = json.load(f)
instance = mock_set_recur.return_value
self.handler.set_watch_nodes(graph, self.graph_stream, watch_point_id)
assert instance.iscalled()

@pytest.mark.parametrize(
"watch_point_id, watch_nodes, watched, expect_updated_id", [
(2, ["Default"], 0, 2),
(3, ["Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGradV2-op92"], 1, 3)
])
@mock.patch.object(Watchpoint, 'remove_nodes')
@mock.patch.object(Watchpoint, 'add_nodes')
def test_update_watchpoint(self, mock_add_nodes, mock_remove_nodes, watch_point_id, watch_nodes,
watched, expect_updated_id):
"""Test update_watchpoint."""
mock_add_nodes.return_value = None
mock_remove_nodes.return_value = None
with TestCase().assertLogs(logger=log, level='DEBUG') as log_content:
self.handler.update_watchpoint(watch_point_id, watch_nodes, watched)
TestCase().assertIn(f"DEBUG:debugger.debugger:Update watchpoint {expect_updated_id} in cache.",
log_content.output)

@pytest.mark.parametrize(
"watch_point_id, expect_deleted_ids", [
@@ -119,42 +142,52 @@ class TestWatchpointHandler:

class TestWatchpointHitHandler:
"""Test WatchpointHitHandler."""
watchpoint = Watchpoint(
watch_condition={'condition': 'MAX_GT', 'param': 1},
watchpoint_id=1
)
value = {
'tensor_proto': mock_tensor_proto(),
'watchpoint': watchpoint,
'node_name': 'Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGradV2-op92',
'finished': True,
'slot': 0
}

@classmethod
def setup_class(cls):
"""Setup."""
cls.handler = WatchpointHitHandler()
cls.tensor_proto = mock_tensor_proto()
cls.handler = init_watchpoint_hit_handler(cls.value)
cls.tensor_hist = mock_tensor_history()
cls.results_dir = os.path.join(os.path.dirname(__file__),
'../expected_results/watchpoint')

@mock.patch('mindinsight.debugger.stream_cache.watchpoint.WatchpointHit')
def test_put(self, mock_hit):
"""Test put."""
value = {
'tensor_proto': self.tensor_proto,
'watchpoint': {'id': 1, 'watch_condition': {'condition': 'INF'}},
'node_name': 'Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGradV2-op92'
}
mock_hit.return_value = mock.MagicMock(
tensor_proto=value.get('tensor_proto'),
watchpoint=value.get('watchpoint'),
node_name=value.get('node_name')
tensor_proto=self.value.get('tensor_proto'),
watchpoint=self.value.get('watchpoint'),
node_name=self.value.get('node_name')
)
self.handler.put(value)
WatchpointHitHandler().put(self.value)

@pytest.mark.parametrize("filter_condition", [
None, "Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/Cast-op190"
@pytest.mark.parametrize("filter_condition, result_file", [
(None, "watchpoint_hit_handler_get_0.json"),
("Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/Cast-op190",
"watchpoint_hit_handler_get_1.json")
])
@mock.patch.object(WatchpointHitHandler, 'get_watchpoint_hits')
def test_get(self, mock_get, filter_condition):
def test_get(self, filter_condition, result_file):
"""Test get."""
mock_get.return_value = {'watch_point_hits': []}
self.handler.get(filter_condition)
reply = self.handler.get(filter_condition)
file_path = os.path.join(self.results_dir, result_file)
compare_result_with_file(reply, file_path)

@mock.patch.object(WatchpointHitHandler, '_is_tensor_hit')
def test_update_tensor_history(self, mock_hit):
def test_update_tensor_history(self):
"""Test update_tensor_history."""
mock_hit.side_effect = [True, False]
self.handler.update_tensor_history(self.tensor_hist)
for tensor_info in self.tensor_hist.get('tensor_history'):
assert tensor_info['is_hit'] is False


def test_validate_watch_condition_type_error():


Loading…
Cancel
Save