Browse Source

add ut for debugger

tags/v1.1.0
yelihua 5 years ago
parent
commit
5471d49342
9 changed files with 528 additions and 14 deletions
  1. +6
    -3
      mindinsight/debugger/common/exceptions/exceptions.py
  2. +7
    -9
      mindinsight/debugger/debugger_server.py
  3. +6
    -1
      tests/ut/debugger/__init__.py
  4. +14
    -0
      tests/ut/debugger/configurations.py
  5. +1
    -0
      tests/ut/debugger/expected_results/debugger_server/retrieve_all.json
  6. +1
    -0
      tests/ut/debugger/expected_results/debugger_server/retrieve_tensor_history.json
  7. +6
    -1
      tests/ut/debugger/stream_handler/test_watchpoint_handler.py
  8. +237
    -0
      tests/ut/debugger/test_debugger_grpc_server.py
  9. +250
    -0
      tests/ut/debugger/test_debugger_server.py

+ 6
- 3
mindinsight/debugger/common/exceptions/exceptions.py View File

@@ -78,7 +78,8 @@ class DebuggerCompareTensorError(MindInsightException):
def __init__(self, msg):
super(DebuggerCompareTensorError, self).__init__(
error=DebuggerErrors.COMPARE_TENSOR_ERROR,
message=DebuggerErrorMsg.COMPARE_TENSOR_ERROR.value.format(msg)
message=msg,
http_code=400
)


@@ -111,7 +112,8 @@ class DebuggerNodeNotInGraphError(MindInsightException):
err_msg = f"Cannot find the node in graph by the given name. node name: {node_name}."
super(DebuggerNodeNotInGraphError, self).__init__(
error=DebuggerErrors.NODE_NOT_IN_GRAPH_ERROR,
message=err_msg
message=err_msg,
http_code=400
)


@@ -120,5 +122,6 @@ class DebuggerGraphNotExistError(MindInsightException):
def __init__(self):
super(DebuggerGraphNotExistError, self).__init__(
error=DebuggerErrors.GRAPH_NOT_EXIST_ERROR,
message=DebuggerErrorMsg.GRAPH_NOT_EXIST_ERROR.value
message=DebuggerErrorMsg.GRAPH_NOT_EXIST_ERROR.value,
http_code=400
)

+ 7
- 9
mindinsight/debugger/debugger_server.py View File

@@ -24,7 +24,8 @@ from mindinsight.datavisual.data_transform.graph import NodeTypeEnum
from mindinsight.datavisual.utils.tools import to_float
from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \
DebuggerParamTypeError, DebuggerCreateWatchPointError, DebuggerUpdateWatchPointError, \
DebuggerDeleteWatchPointError, DebuggerContinueError, DebuggerPauseError, DebuggerCompareTensorError
DebuggerDeleteWatchPointError, DebuggerContinueError, DebuggerPauseError, \
DebuggerCompareTensorError
from mindinsight.debugger.common.log import logger as log
from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \
create_view_event_from_tensor_history, Streams, is_scope_type, NodeBasicInfo
@@ -146,13 +147,10 @@ class DebuggerServer:
node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name)
tolerance = to_float(tolerance, 'tolerance')
tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR)
if detail == 'data':
if node_type == NodeTypeEnum.PARAMETER.value:
reply = tensor_stream.get_tensors_diff(tensor_name, parsed_shape, tolerance)
else:
raise DebuggerParamValueError("The node type must be parameter, but got {}.".format(node_type))
if node_type == NodeTypeEnum.PARAMETER.value:
reply = tensor_stream.get_tensors_diff(tensor_name, parsed_shape, tolerance)
else:
raise DebuggerParamValueError("The value of detail: {} is not support.".format(detail))
raise DebuggerParamValueError("The node type must be parameter, but got {}.".format(node_type))
return reply

def retrieve(self, mode, filter_condition=None):
@@ -177,8 +175,8 @@ class DebuggerServer:
# validate param <mode>
if mode not in mode_mapping.keys():
log.error("Invalid param <mode>. <mode> should be in ['all', 'node', 'watchpoint', "
"'watchpoint_hit', 'tensor'], but got %s.", mode_mapping)
raise DebuggerParamTypeError("Invalid mode.")
"'watchpoint_hit'], but got %s.", mode_mapping)
raise DebuggerParamValueError("Invalid mode.")
# validate backend status
metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
if metadata_stream.state == ServerStatus.PENDING.value:


+ 6
- 1
tests/ut/debugger/__init__.py View File

@@ -12,4 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test for debugger module."""
"""
Function:
Unit test for debugger module.
Usage:
pytest tests/ut/debugger/
"""

+ 14
- 0
tests/ut/debugger/configurations.py View File

@@ -23,10 +23,12 @@ from mindinsight.debugger.common.utils import NodeBasicInfo
from mindinsight.debugger.proto import ms_graph_pb2
from mindinsight.debugger.stream_handler.graph_handler import GraphHandler
from mindinsight.debugger.stream_handler.watchpoint_handler import WatchpointHitHandler
from tests.utils.tools import compare_result_with_file

GRAPH_PROTO_FILE = os.path.join(
os.path.dirname(__file__), '../../utils/resource/graph_pb/lenet.pb'
)
DEBUGGER_EXPECTED_RESULTS = os.path.join(os.path.dirname(__file__), 'expected_results')


def get_graph_proto():
@@ -137,3 +139,15 @@ def mock_tensor_history():
}

return tensor_history


def compare_debugger_result_with_file(res, expect_file):
"""
Compare debugger result with file.

Args:
res (dict): The debugger result in dict type.
expect_file: The expected file name.
"""
real_path = os.path.join(DEBUGGER_EXPECTED_RESULTS, expect_file)
compare_result_with_file(res, real_path)

+ 1
- 0
tests/ut/debugger/expected_results/debugger_server/retrieve_all.json View File

@@ -0,0 +1 @@
{"metadata": {"state": "waiting", "step": 0, "device_name": "", "pos": "0", "ip": "", "node_name": "", "backend": ""}, "graph": {}, "watch_points": []}

+ 1
- 0
tests/ut/debugger/expected_results/debugger_server/retrieve_tensor_history.json View File

@@ -0,0 +1 @@
{"tensor_history": [{"name": "Default/TransData-op99:0", "full_name": "Default/TransData-op99:0", "node_type": "TransData", "type": "output", "step": 0, "dtype": "DT_FLOAT32", "shape": [2, 3], "has_prev_step": false, "value": "click to view"}, {"name": "Default/args0:0", "full_name": "Default/args0:0", "node_type": "Parameter", "type": "input", "step": 0, "dtype": "DT_FLOAT32", "shape": [2, 3], "has_prev_step": false, "value": "click to view"}], "metadata": {"state": "waiting", "step": 0, "device_name": "", "pos": "0", "ip": "", "node_name": "", "backend": ""}}

+ 6
- 1
tests/ut/debugger/stream_handler/test_watchpoint_handler.py View File

@@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test WatchpointHandler."""
"""
Function:
Test query debugger watchpoint handler.
Usage:
pytest tests/ut/debugger
"""
import json
import os
from unittest import mock, TestCase


+ 237
- 0
tests/ut/debugger/test_debugger_grpc_server.py View File

@@ -0,0 +1,237 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Function:
Test debugger grpc server.
Usage:
pytest tests/ut/debugger/test_debugger_grpc_server.py
"""
from unittest import mock
from unittest.mock import MagicMock

import numpy as np

from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus
from mindinsight.debugger.debugger_cache import DebuggerCache
from mindinsight.debugger.debugger_grpc_server import DebuggerGrpcServer
from mindinsight.debugger.proto.debug_grpc_pb2 import EventReply, SetCMD, Chunk, WatchpointHit
from mindinsight.debugger.proto.ms_graph_pb2 import TensorProto, DataType
from mindinsight.debugger.stream_handler import WatchpointHitHandler, GraphHandler, \
WatchpointHandler
from tests.ut.debugger.configurations import GRAPH_PROTO_FILE


class MockDataGenerator:
"""Mocked Data generator."""

@staticmethod
def get_run_cmd(steps=0, level='step', node_name=''):
"""Get run command."""
event = get_ack_reply()
event.run_cmd.run_level = level
if level == 'node':
event.run_cmd.node_name = node_name
else:
event.run_cmd.run_steps = steps
return event

@staticmethod
def get_exit_cmd():
"""Get exit command."""
event = get_ack_reply()
event.exit = True
return event

@staticmethod
def get_set_cmd():
"""Get set command"""
event = get_ack_reply()
event.set_cmd.CopyFrom(SetCMD(id=1, watch_condition=1))
return event

@staticmethod
def get_view_cmd():
"""Get set command"""
view_event = get_ack_reply()
ms_tensor = view_event.view_cmd.tensors.add()
ms_tensor.node_name, ms_tensor.slot = 'mock_node_name', '0'
event = {'view_cmd': view_event, 'node_name': 'mock_node_name'}
return event

@staticmethod
def get_graph_chunks():
"""Get graph chunks."""
chunk_size = 1024
with open(GRAPH_PROTO_FILE, 'rb') as file_handler:
content = file_handler.read()
chunks = [Chunk(buffer=content[0:chunk_size]), Chunk(buffer=content[chunk_size:])]
return chunks

@staticmethod
def get_tensors():
"""Get tensors."""
tensor_content = np.asarray([1, 2, 3, 4, 5, 6]).astype(np.float32).tobytes()
tensor_pre = TensorProto(
node_name='mock_node_name',
slot='0',
data_type=DataType.DT_FLOAT32,
dims=[2, 3],
tensor_content=tensor_content[:12],
finished=0
)
tensor_succ = TensorProto()
tensor_succ.CopyFrom(tensor_pre)
tensor_succ.tensor_content = tensor_content[12:]
tensor_succ.finished = 1
return [tensor_pre, tensor_succ]

@staticmethod
def get_watchpoint_hit():
"""Get watchpoint hit."""
res = WatchpointHit(id=1)
res.tensor.node_name = 'mock_node_name'
res.tensor.slot = '0'
return res


class TestDebuggerGrpcServer:
"""Test debugger grpc server."""

@classmethod
def setup_class(cls):
"""Initialize for test class."""
cls._server = None

def setup_method(self):
"""Initialize for each testcase."""
cache_store = DebuggerCache()
self._server = DebuggerGrpcServer(cache_store)

def test_waitcmd_with_pending_status(self):
"""Test wait command interface when status is pending."""
res = self._server.WaitCMD(MagicMock(), MagicMock())
assert res.status == EventReply.Status.FAILED

@mock.patch.object(WatchpointHitHandler, 'empty', False)
@mock.patch.object(DebuggerGrpcServer, '_deal_with_old_command')
def test_waitcmd_with_old_command(self, *args):
"""Test wait command interface with old command."""
old_command = MockDataGenerator.get_run_cmd(steps=1)
args[0].return_value = old_command
setattr(self._server, '_status', ServerStatus.WAITING)
setattr(self._server, '_received_view_cmd', {'node_name': 'mock_node_name'})
setattr(self._server, '_received_hit', True)
res = self._server.WaitCMD(MagicMock(cur_step=1), MagicMock())
assert res == old_command

@mock.patch.object(DebuggerGrpcServer, '_deal_with_old_command', return_value=None)
@mock.patch.object(DebuggerGrpcServer, '_wait_for_next_command')
def test_waitcmd_with_next_command(self, *args):
"""Test wait for next command."""
old_command = MockDataGenerator.get_run_cmd(steps=1)
args[0].return_value = old_command
setattr(self._server, '_status', ServerStatus.WAITING)
res = self._server.WaitCMD(MagicMock(cur_step=1), MagicMock())
assert res == old_command

@mock.patch.object(DebuggerGrpcServer, '_deal_with_old_command', return_value=None)
@mock.patch.object(DebuggerGrpcServer, '_wait_for_next_command')
def test_waitcmd_with_next_command_is_none(self, *args):
"""Test wait command interface with next command is None."""
args[0].return_value = None
setattr(self._server, '_status', ServerStatus.RECEIVE_GRAPH)
res = self._server.WaitCMD(MagicMock(cur_step=1), MagicMock())
assert res == get_ack_reply(1)

@mock.patch.object(DebuggerCache, 'get_command', return_value=(0, None))
@mock.patch.object(DebuggerCache, 'has_command')
def test_deal_with_old_command_with_continue_steps(self, *args):
"""Test deal with old command with continue steps."""
args[0].side_effect = [True, False]
setattr(self._server, '_continue_steps', 1)
res = self._server._deal_with_old_command()
assert res == MockDataGenerator.get_run_cmd(steps=1)

@mock.patch.object(DebuggerCache, 'get_command')
@mock.patch.object(DebuggerCache, 'has_command', return_value=True)
def test_deal_with_old_command_with_exit_cmd(self, *args):
"""Test deal with exit command."""
cmd = MockDataGenerator.get_exit_cmd()
args[1].return_value = ('0', cmd)
res = self._server._deal_with_old_command()
assert res == cmd

@mock.patch.object(DebuggerCache, 'get_command')
@mock.patch.object(DebuggerCache, 'has_command', return_value=True)
def test_deal_with_old_command_with_view_cmd(self, *args):
"""Test deal with view command."""
cmd = MockDataGenerator.get_view_cmd()
args[1].return_value = ('0', cmd)
res = self._server._deal_with_old_command()
assert res == cmd.get('view_cmd')
expect_received_view_cmd = {'node_name': cmd.get('node_name'), 'wait_for_tensor': True}
assert getattr(self._server, '_received_view_cmd') == expect_received_view_cmd

@mock.patch.object(DebuggerCache, 'get_command')
def test_wait_for_run_command(self, *args):
"""Test wait for run command."""
cmd = MockDataGenerator.get_run_cmd(steps=2)
args[0].return_value = ('0', cmd)
setattr(self._server, '_status', ServerStatus.WAITING)
res = self._server._wait_for_next_command()
assert res == MockDataGenerator.get_run_cmd(steps=1)
assert getattr(self._server, '_continue_steps') == 1

@mock.patch.object(DebuggerCache, 'get_command')
def test_wait_for_pause_and_run_command(self, *args):
"""Test wait for run command."""
pause_cmd = MockDataGenerator.get_run_cmd(steps=0)
empty_view_cmd = MockDataGenerator.get_view_cmd()
empty_view_cmd.pop('node_name')
run_cmd = MockDataGenerator.get_run_cmd(steps=2)
args[0].side_effect = [('0', pause_cmd), ('0', empty_view_cmd), ('0', run_cmd)]
setattr(self._server, '_status', ServerStatus.WAITING)
res = self._server._wait_for_next_command()
assert res == run_cmd
assert getattr(self._server, '_continue_steps') == 1

def test_send_matadata(self):
"""Test SendMatadata interface."""
res = self._server.SendMetadata(MagicMock(training_done=False), MagicMock())
assert res == get_ack_reply()

def test_send_matadata_with_training_done(self):
"""Test SendMatadata interface."""
res = self._server.SendMetadata(MagicMock(training_done=True), MagicMock())
assert res == get_ack_reply()

def test_send_graph(self):
"""Test SendGraph interface."""
res = self._server.SendGraph(MockDataGenerator.get_graph_chunks(), MagicMock())
assert res == get_ack_reply()

def test_send_tensors(self):
"""Test SendTensors interface."""
res = self._server.SendTensors(MockDataGenerator.get_tensors(), MagicMock())
assert res == get_ack_reply()

@mock.patch.object(WatchpointHandler, 'get_watchpoint_by_id')
@mock.patch.object(GraphHandler, 'get_node_name_by_full_name')
def test_send_watchpoint_hit(self, *args):
"""Test SendWatchpointHits interface."""
args[0].side_effect = [None, 'mock_full_name']
watchpoint_hit = MockDataGenerator.get_watchpoint_hit()
res = self._server.SendWatchpointHits([watchpoint_hit, watchpoint_hit], MagicMock())
assert res == get_ack_reply()

+ 250
- 0
tests/ut/debugger/test_debugger_server.py View File

@@ -0,0 +1,250 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Function:
Test debugger server.
Usage:
pytest tests/ut/debugger/test_debugger_server.py
"""
import signal
from threading import Thread
from unittest import mock
from unittest.mock import MagicMock

import grpc
import pytest

from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \
DebuggerCompareTensorError, DebuggerCreateWatchPointError, DebuggerDeleteWatchPointError
from mindinsight.debugger.debugger_cache import DebuggerCache
from mindinsight.debugger.debugger_server import DebuggerServer
from mindinsight.debugger.debugger_server import grpc_server_base
from mindinsight.debugger.proto.debug_grpc_pb2 import RunCMD
from mindinsight.debugger.stream_handler import GraphHandler, WatchpointHandler, MetadataHandler, \
TensorHandler
from tests.ut.debugger.configurations import compare_debugger_result_with_file, mock_tensor_history


class TestDebuggerServer:
"""Test debugger server."""

@classmethod
def setup_class(cls):
"""Initialize for test class."""
cls._server = None

def setup_method(self):
"""Prepare debugger server object."""
self._server = DebuggerServer()

@mock.patch.object(signal, 'signal')
@mock.patch.object(Thread, 'join')
@mock.patch.object(Thread, 'start')
@mock.patch.object(grpc_server_base, 'add_EventListenerServicer_to_server')
@mock.patch.object(grpc, 'server')
def test_stop_server(self, *args):
"""Test stop debugger server."""
mock_grpc_server_manager = MagicMock()
args[0].return_value = mock_grpc_server_manager
self._server.start()
self._server._stop_handler(MagicMock(), MagicMock())
assert self._server.back_server is not None
assert self._server.grpc_server_manager == mock_grpc_server_manager

@mock.patch.object(DebuggerCache, 'get_data')
def test_poll_data(self, *args):
"""Test poll data request."""
mock_data = {'pos': 'mock_data'}
args[0].return_value = mock_data
res = self._server.poll_data('0')
assert res == mock_data

def test_poll_data_with_exept(self):
"""Test poll data with wrong input."""
with pytest.raises(DebuggerParamValueError, match='Pos should be string.'):
self._server.poll_data(1)

@mock.patch.object(GraphHandler, 'search_nodes')
def test_search(self, *args):
"""Test search node."""
mock_graph = {'nodes': ['mock_nodes']}
args[0].return_value = mock_graph
res = self._server.search('mock_name')
assert res == mock_graph

def test_tensor_comparision_with_wrong_status(self):
"""Test tensor comparison with wrong status."""
with pytest.raises(
DebuggerCompareTensorError,
match='Failed to compare tensors as the MindSpore is not in waiting state.'):
self._server.tensor_comparisons(name='mock_node_name:0', shape='[:, :]')

@mock.patch.object(MetadataHandler, 'state', 'waiting')
@mock.patch.object(GraphHandler, 'get_node_type')
@mock.patch.object(GraphHandler, 'get_full_name', return_value='mock_node_name')
def test_tensor_comparision_with_wrong_type(self, *args):
"""Test tensor comparison with wrong type."""
args[1].return_value = 'name_scope'
with pytest.raises(DebuggerParamValueError, match='The node type must be parameter'):
self._server.tensor_comparisons(name='mock_node_name:0', shape='[:, :]')

@mock.patch.object(MetadataHandler, 'state', 'waiting')
@mock.patch.object(GraphHandler, 'get_node_type', return_value='Parameter')
@mock.patch.object(GraphHandler, 'get_full_name', return_value='mock_node_name')
@mock.patch.object(TensorHandler, 'get_tensors_diff')
def test_tensor_comparision(self, *args):
"""Test tensor comparison"""
mock_diff_res = {'tensor_value': {}}
args[0].return_value = mock_diff_res
res = self._server.tensor_comparisons('mock_node_name:0', '[:, :]')
assert res == mock_diff_res

def test_retrieve_with_pending(self):
"""Test retrieve request in pending status."""
res = self._server.retrieve(mode='all')
assert res.get('metadata', {}).get('state') == 'pending'

@mock.patch.object(MetadataHandler, 'state', 'waiting')
def test_retrieve_all(self):
"""Test retrieve request."""
res = self._server.retrieve(mode='all')
compare_debugger_result_with_file(res, 'debugger_server/retrieve_all.json')

def test_retrieve_with_invalid_mode(self):
"""Test retrieve with invalid mode."""
with pytest.raises(DebuggerParamValueError, match='Invalid mode.'):
self._server.retrieve(mode='invalid_mode')

@mock.patch.object(GraphHandler, 'get')
@mock.patch.object(GraphHandler, 'get_node_type', return_value='name_scope')
@mock.patch.object(GraphHandler, 'get_full_name', return_value='mock_node_name')
def test_retrieve_node(self, *args):
"""Test retrieve node information."""
mock_graph = {'graph': {}}
args[2].return_value = mock_graph
res = self._server._retrieve_node({'name': 'mock_node_name'})
assert res == mock_graph

def test_retrieve_tensor_history_with_pending(self):
"""Test retrieve request in pending status."""
res = self._server.retrieve_tensor_history('mock_node_name')
assert res.get('metadata', {}).get('state') == 'pending'

@mock.patch.object(MetadataHandler, 'state', 'waiting')
@mock.patch.object(GraphHandler, 'get_tensor_history')
@mock.patch.object(GraphHandler, 'get_node_type', return_value='Parameter')
def test_retrieve_tensor_history(self, *args):
"""Test retrieve tensor history."""
args[1].return_value = mock_tensor_history()
res = self._server.retrieve_tensor_history('mock_node_name')
compare_debugger_result_with_file(res, 'debugger_server/retrieve_tensor_history.json')

@mock.patch.object(GraphHandler, 'get_node_type')
def test_validate_leaf_name(self, *args):
"""Test validate leaf name."""
args[0].return_value = 'name_scope'
with pytest.raises(DebuggerParamValueError, match='Invalid leaf node name.'):
self._server._validate_leaf_name(node_name='mock_node_name')

@mock.patch.object(TensorHandler, 'get')
@mock.patch.object(DebuggerServer, '_get_tensor_name_and_type_by_ui_name')
def test_retrieve_tensor_value(self, *args):
"""Test retrieve tensor value."""
mock_tensor_value = {'tensor_value': {'name': 'mock_name:0'}}
args[0].return_value = ('Parameter', 'mock_node_name')
args[1].return_value = mock_tensor_value
res = self._server.retrieve_tensor_value('mock_name:0', 'data', '[:, :]')
assert res == mock_tensor_value

@mock.patch.object(WatchpointHandler, 'get')
def test_retrieve_watchpoints(self, *args):
"""Test retrieve watchpoints."""
mock_watchpoint = {'watch_points': {}}
args[0].return_value = mock_watchpoint
res = self._server._retrieve_watchpoint({})
assert res == mock_watchpoint

@mock.patch.object(DebuggerServer, '_retrieve_node')
def test_retrieve_watchpoint(self, *args):
"""Test retrieve single watchpoint."""
mock_watchpoint = {'nodes': {}}
args[0].return_value = mock_watchpoint
res = self._server._retrieve_watchpoint({'watch_point_id': 1})
assert res == mock_watchpoint

@mock.patch.object(DebuggerServer, '_validate_leaf_name')
@mock.patch.object(DebuggerServer, '_get_tensor_history')
@mock.patch.object(DebuggerServer, '_get_nodes_info', return_value={'graph': {}})
def test_retrieve_watchpoint_hit(self, *args):
"""Test retrieve single watchpoint."""
args[1].return_value = {'tensor_history': {}}
res = self._server._retrieve_watchpoint_hit({'name': 'hit_node_name', 'single_node': True})
assert res == {'tensor_history': {}, 'graph': {}}

def test_create_watchpoint_with_wrong_state(self):
"""Test create watchpoint with wrong state."""
with pytest.raises(DebuggerCreateWatchPointError, match='Failed to create watchpoint'):
self._server.create_watchpoint(watch_condition={'condition': 'INF'})

@mock.patch.object(MetadataHandler, 'state', 'waiting')
@mock.patch.object(GraphHandler, 'get_full_name', return_value='mock_full_name')
@mock.patch.object(GraphHandler, 'get_full_name', return_value='mock_full_name')
@mock.patch.object(GraphHandler, 'get_nodes_by_scope', return_value=[MagicMock()])
@mock.patch.object(GraphHandler, 'get_node_type', return_value='aggregation_scope')
@mock.patch.object(WatchpointHandler, 'create_watchpoint')
def test_create_watchpoint(self, *args):
"""Test create watchpoint."""
args[0].return_value = 1
res = self._server.create_watchpoint({'condition': 'INF'}, ['watch_node_name'])
assert res == {'id': 1}

@mock.patch.object(MetadataHandler, 'state', 'waiting')
@mock.patch.object(GraphHandler, 'get_searched_node_list')
@mock.patch.object(WatchpointHandler, 'validate_watchpoint_id')
@mock.patch.object(WatchpointHandler, 'update_watchpoint')
def test_update_watchpoint(self, *args):
"""Test update watchpoint."""
args[2].return_value = [MagicMock(name='seatch_name/op_name')]
res = self._server.update_watchpoint(
watch_point_id=1, watch_nodes=['search_name'], mode=1, name='search_name')
assert res == {}

def test_delete_watchpoint_with_wrong_state(self):
"""Test delete watchpoint with wrong state."""
with pytest.raises(DebuggerDeleteWatchPointError, match='Failed to delete watchpoint'):
self._server.delete_watchpoint(watch_point_id=1)

@mock.patch.object(MetadataHandler, 'state', 'waiting')
@mock.patch.object(WatchpointHandler, 'delete_watchpoint')
def test_delete_watchpoint(self, *args):
"""Test delete watchpoint with wrong state."""
args[0].return_value = None
res = self._server.delete_watchpoint(1)
assert res == {}

@pytest.mark.parametrize('mode, cur_state, state', [
('continue', 'waiting', 'running'),
('pause', 'running', 'waiting'),
('terminate', 'waiting', 'pending')])
def test_control(self, mode, cur_state, state):
"""Test control request."""
with mock.patch.object(MetadataHandler, 'state', cur_state):
res = self._server.control({'mode': mode})
assert res == {'metadata': {'state': state}}

def test_construct_run_event(self):
"""Test construct run event."""
res = self._server._construct_run_event({'level': 'node'})
assert res.run_cmd == RunCMD(run_level='node', node_name='')

Loading…
Cancel
Save