|
|
@@ -17,6 +17,7 @@ import json |
|
|
import os |
|
|
import os |
|
|
from unittest import mock, TestCase |
|
|
from unittest import mock, TestCase |
|
|
|
|
|
|
|
|
|
|
|
from google.protobuf import json_format |
|
|
import pytest |
|
|
import pytest |
|
|
|
|
|
|
|
|
from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \ |
|
|
from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \ |
|
|
@@ -27,7 +28,9 @@ from mindinsight.debugger.stream_handler.watchpoint_handler import WatchpointHan |
|
|
WatchpointHitHandler, validate_watch_condition, validate_watch_condition_params |
|
|
WatchpointHitHandler, validate_watch_condition, validate_watch_condition_params |
|
|
|
|
|
|
|
|
from tests.ut.debugger.configurations import init_graph_handler, mock_tensor_proto, \ |
|
|
from tests.ut.debugger.configurations import init_graph_handler, mock_tensor_proto, \ |
|
|
mock_tensor_history |
|
|
|
|
|
|
|
|
mock_tensor_history, get_node_basic_infos, get_watch_nodes_by_search, \ |
|
|
|
|
|
init_watchpoint_hit_handler |
|
|
|
|
|
from tests.utils.tools import compare_result_with_file |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestWatchpointHandler: |
|
|
class TestWatchpointHandler: |
|
|
@@ -36,6 +39,8 @@ class TestWatchpointHandler: |
|
|
def setup_class(cls): |
|
|
def setup_class(cls): |
|
|
"""Init WatchpointHandler for watchpoint unittest.""" |
|
|
"""Init WatchpointHandler for watchpoint unittest.""" |
|
|
cls.handler = WatchpointHandler() |
|
|
cls.handler = WatchpointHandler() |
|
|
|
|
|
cls.results_dir = os.path.join(os.path.dirname(__file__), |
|
|
|
|
|
'../expected_results/watchpoint') |
|
|
cls.graph_results_dir = os.path.join(os.path.dirname(__file__), |
|
|
cls.graph_results_dir = os.path.join(os.path.dirname(__file__), |
|
|
'../expected_results/graph') |
|
|
'../expected_results/graph') |
|
|
cls.graph_stream = init_graph_handler() |
|
|
cls.graph_stream = init_graph_handler() |
|
|
@@ -47,26 +52,64 @@ class TestWatchpointHandler: |
|
|
({'condition': 'MAX_GT', 'param': 1}, |
|
|
({'condition': 'MAX_GT', 'param': 1}, |
|
|
["Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGradV2-op92"], None, 3) |
|
|
["Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGradV2-op92"], None, 3) |
|
|
]) |
|
|
]) |
|
|
@mock.patch.object(Watchpoint, 'add_nodes') |
|
|
|
|
|
def test_create_watchpoint(self, mock_add_nodes, watch_condition, watch_nodes, |
|
|
|
|
|
|
|
|
def test_create_watchpoint(self, watch_condition, watch_nodes, |
|
|
watch_point_id, expect_new_id): |
|
|
watch_point_id, expect_new_id): |
|
|
"""Test create_watchpoint.""" |
|
|
"""Test create_watchpoint.""" |
|
|
mock_add_nodes.return_value = None |
|
|
|
|
|
|
|
|
watch_nodes = get_node_basic_infos(watch_nodes) |
|
|
watch_point_id = self.handler.create_watchpoint(watch_condition, watch_nodes, watch_point_id) |
|
|
watch_point_id = self.handler.create_watchpoint(watch_condition, watch_nodes, watch_point_id) |
|
|
assert watch_point_id == expect_new_id |
|
|
assert watch_point_id == expect_new_id |
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("filter_condition", [True, False]) |
|
|
|
|
|
@mock.patch.object(Watchpoint, 'get_set_cmd') |
|
|
|
|
|
@mock.patch.object(Watchpoint, 'get_watch_condition_info') |
|
|
|
|
|
def test_get(self, mock_get_wp, mock_get_cmd, filter_condition): |
|
|
|
|
|
"""Test get.""" |
|
|
|
|
|
mock_get_wp.return_value = None |
|
|
|
|
|
mock_get_cmd.return_value = None |
|
|
|
|
|
|
|
|
@pytest.mark.parametrize( |
|
|
|
|
|
"watch_point_id, watch_nodes, watched, expect_updated_id", [ |
|
|
|
|
|
(3, ["Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGradV2-op92"], 1, 3), |
|
|
|
|
|
(3, [], 1, 3) |
|
|
|
|
|
]) |
|
|
|
|
|
def test_update_watchpoint_add(self, watch_point_id, watch_nodes, watched, expect_updated_id): |
|
|
|
|
|
"""Test update_watchpoint on addition.""" |
|
|
|
|
|
watch_nodes = get_node_basic_infos(watch_nodes) |
|
|
|
|
|
with TestCase().assertLogs(logger=log, level='DEBUG') as log_content: |
|
|
|
|
|
self.handler.update_watchpoint(watch_point_id, watch_nodes, watched) |
|
|
|
|
|
TestCase().assertIn(f"DEBUG:debugger.debugger:Update watchpoint {expect_updated_id} in cache.", |
|
|
|
|
|
log_content.output) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize( |
|
|
|
|
|
"watch_point_id, watch_nodes, watched, expect_updated_id", [ |
|
|
|
|
|
(2, ["Default"], 0, 2), |
|
|
|
|
|
(2, [], 0, 2), |
|
|
|
|
|
]) |
|
|
|
|
|
def test_update_watchpoint_delete(self, watch_point_id, watch_nodes, watched, expect_updated_id): |
|
|
|
|
|
"""Test update_watchpoint on deletion.""" |
|
|
|
|
|
watch_nodes = get_watch_nodes_by_search(watch_nodes) |
|
|
with TestCase().assertLogs(logger=log, level='DEBUG') as log_content: |
|
|
with TestCase().assertLogs(logger=log, level='DEBUG') as log_content: |
|
|
self.handler.get(filter_condition) |
|
|
|
|
|
TestCase().assertIn(f"DEBUG:debugger.debugger:get the watch points with filter_condition:{filter_condition}", |
|
|
|
|
|
|
|
|
self.handler.update_watchpoint(watch_point_id, watch_nodes, watched) |
|
|
|
|
|
TestCase().assertIn(f"DEBUG:debugger.debugger:Update watchpoint {expect_updated_id} in cache.", |
|
|
log_content.output) |
|
|
log_content.output) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("filter_condition, result_file", [ |
|
|
|
|
|
(True, 'watchpoint_handler_get_0.json') |
|
|
|
|
|
]) |
|
|
|
|
|
def test_get_filter_true(self, filter_condition, result_file): |
|
|
|
|
|
"""Test get with filter_condition is True.""" |
|
|
|
|
|
file_path = os.path.join(self.results_dir, result_file) |
|
|
|
|
|
with open(file_path, 'r') as f: |
|
|
|
|
|
contents = json.load(f) |
|
|
|
|
|
|
|
|
|
|
|
reply = self.handler.get(filter_condition) |
|
|
|
|
|
protos = reply.get('watch_points') |
|
|
|
|
|
for proto in protos: |
|
|
|
|
|
msg_dict = json_format.MessageToDict(proto) |
|
|
|
|
|
assert msg_dict in contents |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("filter_condition, result_file", [ |
|
|
|
|
|
(False, 'watchpoint_handler_get_1.json') |
|
|
|
|
|
]) |
|
|
|
|
|
def test_get_filter_false(self, filter_condition, result_file): |
|
|
|
|
|
"""Test get with filer_condition is False.""" |
|
|
|
|
|
file_path = os.path.join(self.results_dir, result_file) |
|
|
|
|
|
reply = self.handler.get(filter_condition) |
|
|
|
|
|
watch_points = reply.get('watch_points') |
|
|
|
|
|
compare_result_with_file(watch_points, file_path) |
|
|
|
|
|
|
|
|
def test_get_watchpoint_by_id_except(self): |
|
|
def test_get_watchpoint_by_id_except(self): |
|
|
"""Test get_watchpoint_by_id.""" |
|
|
"""Test get_watchpoint_by_id.""" |
|
|
watchpoint_id = 4 |
|
|
watchpoint_id = 4 |
|
|
@@ -78,32 +121,12 @@ class TestWatchpointHandler: |
|
|
@pytest.mark.parametrize("graph_file, watch_point_id", [ |
|
|
@pytest.mark.parametrize("graph_file, watch_point_id", [ |
|
|
('graph_handler_get_3_single_node.json', 4) |
|
|
('graph_handler_get_3_single_node.json', 4) |
|
|
]) |
|
|
]) |
|
|
@mock.patch.object(WatchpointHandler, '_set_watch_status_recursively') |
|
|
|
|
|
def test_set_watch_nodes(self, mock_set_recur, graph_file, watch_point_id): |
|
|
|
|
|
|
|
|
def test_set_watch_nodes(self, graph_file, watch_point_id): |
|
|
"""Test set_watch_nodes.""" |
|
|
"""Test set_watch_nodes.""" |
|
|
path = os.path.join(self.graph_results_dir, graph_file) |
|
|
path = os.path.join(self.graph_results_dir, graph_file) |
|
|
with open(path, 'r') as f: |
|
|
with open(path, 'r') as f: |
|
|
graph = json.load(f) |
|
|
graph = json.load(f) |
|
|
instance = mock_set_recur.return_value |
|
|
|
|
|
self.handler.set_watch_nodes(graph, self.graph_stream, watch_point_id) |
|
|
self.handler.set_watch_nodes(graph, self.graph_stream, watch_point_id) |
|
|
assert instance.iscalled() |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize( |
|
|
|
|
|
"watch_point_id, watch_nodes, watched, expect_updated_id", [ |
|
|
|
|
|
(2, ["Default"], 0, 2), |
|
|
|
|
|
(3, ["Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGradV2-op92"], 1, 3) |
|
|
|
|
|
]) |
|
|
|
|
|
@mock.patch.object(Watchpoint, 'remove_nodes') |
|
|
|
|
|
@mock.patch.object(Watchpoint, 'add_nodes') |
|
|
|
|
|
def test_update_watchpoint(self, mock_add_nodes, mock_remove_nodes, watch_point_id, watch_nodes, |
|
|
|
|
|
watched, expect_updated_id): |
|
|
|
|
|
"""Test update_watchpoint.""" |
|
|
|
|
|
mock_add_nodes.return_value = None |
|
|
|
|
|
mock_remove_nodes.return_value = None |
|
|
|
|
|
with TestCase().assertLogs(logger=log, level='DEBUG') as log_content: |
|
|
|
|
|
self.handler.update_watchpoint(watch_point_id, watch_nodes, watched) |
|
|
|
|
|
TestCase().assertIn(f"DEBUG:debugger.debugger:Update watchpoint {expect_updated_id} in cache.", |
|
|
|
|
|
log_content.output) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize( |
|
|
@pytest.mark.parametrize( |
|
|
"watch_point_id, expect_deleted_ids", [ |
|
|
"watch_point_id, expect_deleted_ids", [ |
|
|
@@ -119,42 +142,52 @@ class TestWatchpointHandler: |
|
|
|
|
|
|
|
|
class TestWatchpointHitHandler: |
|
|
class TestWatchpointHitHandler: |
|
|
"""Test WatchpointHitHandler.""" |
|
|
"""Test WatchpointHitHandler.""" |
|
|
|
|
|
watchpoint = Watchpoint( |
|
|
|
|
|
watch_condition={'condition': 'MAX_GT', 'param': 1}, |
|
|
|
|
|
watchpoint_id=1 |
|
|
|
|
|
) |
|
|
|
|
|
value = { |
|
|
|
|
|
'tensor_proto': mock_tensor_proto(), |
|
|
|
|
|
'watchpoint': watchpoint, |
|
|
|
|
|
'node_name': 'Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGradV2-op92', |
|
|
|
|
|
'finished': True, |
|
|
|
|
|
'slot': 0 |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
@classmethod |
|
|
@classmethod |
|
|
def setup_class(cls): |
|
|
def setup_class(cls): |
|
|
"""Setup.""" |
|
|
"""Setup.""" |
|
|
cls.handler = WatchpointHitHandler() |
|
|
|
|
|
cls.tensor_proto = mock_tensor_proto() |
|
|
|
|
|
|
|
|
cls.handler = init_watchpoint_hit_handler(cls.value) |
|
|
cls.tensor_hist = mock_tensor_history() |
|
|
cls.tensor_hist = mock_tensor_history() |
|
|
|
|
|
cls.results_dir = os.path.join(os.path.dirname(__file__), |
|
|
|
|
|
'../expected_results/watchpoint') |
|
|
|
|
|
|
|
|
@mock.patch('mindinsight.debugger.stream_cache.watchpoint.WatchpointHit') |
|
|
@mock.patch('mindinsight.debugger.stream_cache.watchpoint.WatchpointHit') |
|
|
def test_put(self, mock_hit): |
|
|
def test_put(self, mock_hit): |
|
|
"""Test put.""" |
|
|
"""Test put.""" |
|
|
value = { |
|
|
|
|
|
'tensor_proto': self.tensor_proto, |
|
|
|
|
|
'watchpoint': {'id': 1, 'watch_condition': {'condition': 'INF'}}, |
|
|
|
|
|
'node_name': 'Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGradV2-op92' |
|
|
|
|
|
} |
|
|
|
|
|
mock_hit.return_value = mock.MagicMock( |
|
|
mock_hit.return_value = mock.MagicMock( |
|
|
tensor_proto=value.get('tensor_proto'), |
|
|
|
|
|
watchpoint=value.get('watchpoint'), |
|
|
|
|
|
node_name=value.get('node_name') |
|
|
|
|
|
|
|
|
tensor_proto=self.value.get('tensor_proto'), |
|
|
|
|
|
watchpoint=self.value.get('watchpoint'), |
|
|
|
|
|
node_name=self.value.get('node_name') |
|
|
) |
|
|
) |
|
|
self.handler.put(value) |
|
|
|
|
|
|
|
|
WatchpointHitHandler().put(self.value) |
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("filter_condition", [ |
|
|
|
|
|
None, "Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/Cast-op190" |
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("filter_condition, result_file", [ |
|
|
|
|
|
(None, "watchpoint_hit_handler_get_0.json"), |
|
|
|
|
|
("Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/Cast-op190", |
|
|
|
|
|
"watchpoint_hit_handler_get_1.json") |
|
|
]) |
|
|
]) |
|
|
@mock.patch.object(WatchpointHitHandler, 'get_watchpoint_hits') |
|
|
|
|
|
def test_get(self, mock_get, filter_condition): |
|
|
|
|
|
|
|
|
def test_get(self, filter_condition, result_file): |
|
|
"""Test get.""" |
|
|
"""Test get.""" |
|
|
mock_get.return_value = {'watch_point_hits': []} |
|
|
|
|
|
self.handler.get(filter_condition) |
|
|
|
|
|
|
|
|
reply = self.handler.get(filter_condition) |
|
|
|
|
|
file_path = os.path.join(self.results_dir, result_file) |
|
|
|
|
|
compare_result_with_file(reply, file_path) |
|
|
|
|
|
|
|
|
@mock.patch.object(WatchpointHitHandler, '_is_tensor_hit') |
|
|
|
|
|
def test_update_tensor_history(self, mock_hit): |
|
|
|
|
|
|
|
|
def test_update_tensor_history(self): |
|
|
"""Test update_tensor_history.""" |
|
|
"""Test update_tensor_history.""" |
|
|
mock_hit.side_effect = [True, False] |
|
|
|
|
|
self.handler.update_tensor_history(self.tensor_hist) |
|
|
self.handler.update_tensor_history(self.tensor_hist) |
|
|
|
|
|
for tensor_info in self.tensor_hist.get('tensor_history'): |
|
|
|
|
|
assert tensor_info['is_hit'] is False |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_validate_watch_condition_type_error(): |
|
|
def test_validate_watch_condition_type_error(): |
|
|
|