|
|
|
@@ -25,19 +25,17 @@ class TestTensorHandler: |
|
|
|
"""Test TensorHandler.""" |
|
|
|
|
|
|
|
def setup_method(self): |
|
|
|
"""Setup method for each test case.""" |
|
|
|
self.tensor_handler = TensorHandler() |
|
|
|
|
|
|
|
@mock.patch.object(TensorHandler, '_get_tensor') |
|
|
|
@mock.patch.object(log, "error") |
|
|
|
@pytest.mark.parametrize("filter_condition", {}) |
|
|
|
def test_get(self, mock_get_tensor, mock_error, filter_condition): |
|
|
|
""" |
|
|
|
Test get full tensor value. |
|
|
|
""" |
|
|
|
def test_get(self, mock_get_tensor, mock_error): |
|
|
|
"""Test get full tensor value.""" |
|
|
|
mock_get_tensor.return_value = None |
|
|
|
mock_error.return_value = None |
|
|
|
with pytest.raises(DebuggerParamValueError) as ex: |
|
|
|
self.tensor_handler.get(filter_condition) |
|
|
|
self.tensor_handler.get({}) |
|
|
|
assert "No tensor named {}".format(None) in str(ex.value) |
|
|
|
|
|
|
|
def test_get_tensor_value_by_name_none(self): |
|
|
|
@@ -48,9 +46,7 @@ class TestTensorHandler: |
|
|
|
@mock.patch.object(log, "error") |
|
|
|
@pytest.mark.parametrize("tensor_name", "name") |
|
|
|
def test_get_tensors_diff_error(self, mock_error, tensor_name): |
|
|
|
""" |
|
|
|
Test get_tensors_diff. |
|
|
|
""" |
|
|
|
"""Test get_tensors_diff.""" |
|
|
|
mock_error.return_value = None |
|
|
|
with pytest.raises(DebuggerParamValueError) as ex: |
|
|
|
self.tensor_handler.get_tensors_diff(tensor_name, {1, 1}) |
|
|
|
|